From da154145356b47afc896793ae93d69ed69c07574 Mon Sep 17 00:00:00 2001 From: ada-3e212e610b Date: Thu, 25 Jun 2026 22:30:31 +0100 Subject: [PATCH 1/8] start adding chord comparison logic --- notebooks/Chord_comparison.ipynb | 182 +++++++++++++++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 notebooks/Chord_comparison.ipynb diff --git a/notebooks/Chord_comparison.ipynb b/notebooks/Chord_comparison.ipynb new file mode 100644 index 0000000..6240deb --- /dev/null +++ b/notebooks/Chord_comparison.ipynb @@ -0,0 +1,182 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a3d9e14b", + "metadata": {}, + "source": [ + "# Chord Comparison" + ] + }, + { + "cell_type": "markdown", + "id": "3eac7376", + "metadata": {}, + "source": [ + "A \"chord\" = multiple(>=3) notes with the same or very close start times\n", + "- Block Chords: multiple(>=3) notes with the same start time\n", + "- Broken Chords / Arpeggios: multiple(>=3) notes play one after another, in a given order\n", + "\n", + "Chord comparison can be considered as a set matching problem so that the correctness of a chord can be determined by the overlap between response set and reference set. Additionally, it may require handling partial matches problem (e.g. two out of three chord notes played correctly), consider defining it as imperfect chords. " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "db6e1544", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "import numpy as np\n", + "\n", + "cwd = os.getcwd()\n", + "\n", + "cwd = os.getcwd()\n", + "dir = os.path.dirname(cwd)\n", + "reference_path = os.path.join(dir, \"data\", \"referenceMIDIchords.json\")\n", + "response_path = os.path.join(dir, \"data\", \"responseMIDIchords.json\")\n", + "\n", + "with open(reference_path) as f1:\n", + " ref_chords_data = json.load(f1)\n", + "\n", + "with open(response_path) as f2:\n", + " res_chords_data = json.load(f2)" + ] + }, + { + "cell_type": "markdown", + "id": "ae192f73", + "metadata": {}, + "source": [ + "chord accuracy metric: https://arxiv.org/pdf/2201.05244" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc06c1bb", + "metadata": {}, + "outputs": [], + "source": [ + "# Default threshold: notes starting within 50ms are grouped as one chord.\n", + "DEFAULT_CHORD_ONSET_WINDOW = 0.05\n", + "\n", + "# Chord template dictionary.\n", + "# Each entry maps a chord name to a frozenset of pitch class intervals,\n", + "# where the root note is normalised to 0.\n", + "# Only the 4 triad types are included, following Muller (2021) Chapter 5.\n", + "CHORD_TEMPLATES = {\n", + " \"major\": frozenset([0, 4, 7]),\n", + " \"minor\": frozenset([0, 3, 7]),\n", + " \"diminished\": frozenset([0, 3, 6]),\n", + " \"augmented\": frozenset([0, 4, 8]),\n", + "}\n", + " \n", + "# Pitch class names for human-readable feedback messages.\n", + "PITCH_CLASS_NAMES = [\"C\", \"C#\", \"D\", \"D#\", \"E\", \"F\",\n", + " \"F#\", \"G\", \"G#\", \"A\", \"A#\", \"B\"]\n", + "\n", + "# helper function to build an event dict from a group of notes\n", + "def make_event(notes_in_group):\n", + " \"\"\"\n", + " Build a single event dict from a group of notes.\n", + " \n", + " Args:\n", + " notes_in_group: list of one or more note dicts.\n", + " \n", + " Returns:\n", + " event dict with keys \"event_type\", \"notes\", \"start\". example:\n", + " {\n", + " \"event_type\": \"note\" or \"chord\" depending on the number of notes in the group\n", + " \"notes\": [ \n", + " { \n", + " \"pitch\": int\n", + " \"start\": float\n", + " \"duration\": float\n", + " },\n", + " ]\n", + " \"start\": float,\n", + " }\n", + " \"\"\"\n", + " event_type = \"note\" if len(notes_in_group) == 1 else \"chord\"\n", + " return {\n", + " \"event_type\": event_type,\n", + " \"notes\": notes_in_group,\n", + " \"start\": notes_in_group[0][\"start\"],\n", + " }\n", + "\n", + "# group notes into events based on their start times\n", + "def group_notes_into_events(notes, chord_onset_window=DEFAULT_CHORD_ONSET_WINDOW):\n", + " \"\"\"\n", + " Group a flat list of notes into events. Notes whose start times fall\n", + " within chord_onset_window seconds of each other are placed into the\n", + " same event (i.e. treated as a chord). Notes that are not grouped with\n", + " any other note form a single-note event.\n", + "\n", + " Args:\n", + " notes: list of dicts, each with keys 'pitch', 'start', 'duration'\n", + " chord_onset_window: float, max time difference (seconds) to be grouped\n", + "\n", + " Returns:\n", + " list of dicts, each with keys:\n", + " 'pitches' : set of int MIDI pitch numbers\n", + " 'start' : float, earliest start time in the group\n", + " 'duration' : float, average duration of notes in the group\n", + " \"\"\"\n", + " if len(notes) == 0:\n", + " return []\n", + "\n", + " # Sort notes by start time first\n", + " sorted_notes = sorted(notes, key=lambda n: n[\"start\"])\n", + "\n", + " events = []\n", + " current_group = [sorted_notes[0]]\n", + " group_start = sorted_notes[0][\"start\"]\n", + "\n", + " for note in sorted_notes[1:]:\n", + " # Close enough in time: add to current group (chord)\n", + " if note[\"start\"] - group_start <= chord_onset_window:\n", + " current_group.append(note)\n", + " else:\n", + " # Too far apart: save current chord, start a new group\n", + " event = make_event(current_group)\n", + " events.append(event)\n", + " current_group = [note]\n", + " group_start = note[\"start\"]\n", + "\n", + " # append the last group\n", + " last_event = make_event(current_group)\n", + " events.append(last_event)\n", + "\n", + " return events\n", + "\n", + "\n", + "def compute_event_cost(event_a, event_b):\n", + " pass" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "compareMusic", + "language": "python", + "name": "comparemusic" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From e32ddf6324de53ac359269d2df1ad1c7b626d7d4 Mon Sep 17 00:00:00 2001 From: ada-3e212e610b Date: Fri, 26 Jun 2026 21:32:47 +0100 Subject: [PATCH 2/8] finish the chord comparison logic, need to be tested --- notebooks/Chord_comparison.ipynb | 1126 +++++++++++++++++++++++++++++- 1 file changed, 1117 insertions(+), 9 deletions(-) diff --git a/notebooks/Chord_comparison.ipynb b/notebooks/Chord_comparison.ipynb index 6240deb..35ceae3 100644 --- a/notebooks/Chord_comparison.ipynb +++ b/notebooks/Chord_comparison.ipynb @@ -16,20 +16,20 @@ "A \"chord\" = multiple(>=3) notes with the same or very close start times\n", "- Block Chords: multiple(>=3) notes with the same start time\n", "- Broken Chords / Arpeggios: multiple(>=3) notes play one after another, in a given order\n", + "focus on block chords first. \n", "\n", "Chord comparison can be considered as a set matching problem so that the correctness of a chord can be determined by the overlap between response set and reference set. Additionally, it may require handling partial matches problem (e.g. two out of three chord notes played correctly), consider defining it as imperfect chords. " ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 11, "id": "db6e1544", "metadata": {}, "outputs": [], "source": [ "import os\n", "import json\n", - "import numpy as np\n", "\n", "cwd = os.getcwd()\n", "\n", @@ -50,12 +50,74 @@ "id": "ae192f73", "metadata": {}, "source": [ - "chord accuracy metric: https://arxiv.org/pdf/2201.05244" + "chord accuracy metric: https://arxiv.org/pdf/2201.05244\n", + "\n", + "consider chord as unordered set of pitch classes\n", + "https://www.researchgate.net/profile/Pierre-Hanna/publication/265984596_A_Survey_Of_Chord_Distances_With_Comparison_For_Chord_Analysis/links/555c4c7808aec5ac2232b158/A-Survey-Of-Chord-Distances-With-Comparison-For-Chord-Analysis.pdf" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, + "id": "c04a662c", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "# Default thresholds / parameters\n", + "# Teachers can override any of these via the params dict in evaluation_function.\n", + "# ------------------------------------------------------------------------------\n", + "# Gap penalty: cost of leaving a note unaligned (insertion/deletion)\n", + "DEFAULT_GAP_PENALTY = 6\n", + "\n", + "# Timing: |response_start - predicted_start| / IOI must be below this.\n", + "# e.g. 0.20 means the start can be off by up to 20% of the inter-onset interval.\n", + "TIMING_RELATIVE_THRESHOLD = 0.20\n", + "\n", + "# Duration: |response_dur / ref_dur - 1| must be below this.\n", + "# e.g. 0.25 means the student's duration can be off by up to 25% of the reference.\n", + "DURATION_RELATIVE_THRESHOLD = 0.25\n", + "\n", + "# Thresholds that trigger a global tempo comment in the overview.\n", + "GLOBAL_SLOW_THRESHOLD = 1.15 # timing_scale > 1.15 -> \"overall too slow\"\n", + "GLOBAL_FAST_THRESHOLD = 0.85 # timing_scale < 0.85 -> \"overall too fast\"\n", + "\n", + "# Step 0 - make first note start at t = 0.0\n", + "# ------------------------------------------------------------------------------\n", + "def normalize_start_times(notes):\n", + " \"\"\"\n", + " Shift all notes so that the first note starts at t=0.\n", + " \n", + " Args:\n", + " notes: list of note dicts, each with at least a \"start\" key.\n", + " \n", + " Returns:\n", + " A new list of note dicts (copies, not the original objects), with\n", + " every \"start\" value shifted so notes[0][\"start\"] == 0. Returns an\n", + " empty list unchanged if notes is empty.\n", + " \"\"\"\n", + " if not notes:\n", + " return []\n", + " \n", + " first_start = notes[0][\"start\"]\n", + " \n", + " shifted_notes = []\n", + " for note in notes:\n", + " # Create a copy of the note dict with the \"start\" time shifted\n", + " note_copy = {\n", + " \"pitch\": note[\"pitch\"],\n", + " \"start\": note[\"start\"] - first_start,\n", + " \"duration\": note[\"duration\"],\n", + " }\n", + " shifted_notes.append(note_copy)\n", + " \n", + " return shifted_notes" + ] + }, + { + "cell_type": "code", + "execution_count": 13, "id": "fc06c1bb", "metadata": {}, "outputs": [], @@ -97,16 +159,21 @@ " \"duration\": float\n", " },\n", " ]\n", - " \"start\": float,\n", + " \"event_start\": float,\n", + " \"event_duration\": float\n", " }\n", " \"\"\"\n", " event_type = \"note\" if len(notes_in_group) == 1 else \"chord\"\n", " return {\n", " \"event_type\": event_type,\n", " \"notes\": notes_in_group,\n", - " \"start\": notes_in_group[0][\"start\"],\n", + " # use the start time of the first note in the group as the event start\n", + " \"event_start\": notes_in_group[0][\"start\"], \n", + " # use the duration of the first note in the group as the event duration\n", + " \"event_duration\": notes_in_group[0][\"duration\"], \n", " }\n", "\n", + "\n", "# group notes into events based on their start times\n", "def group_notes_into_events(notes, chord_onset_window=DEFAULT_CHORD_ONSET_WINDOW):\n", " \"\"\"\n", @@ -150,11 +217,1052 @@ " last_event = make_event(current_group)\n", " events.append(last_event)\n", "\n", - " return events\n", + " return events" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "6dbac7c6", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_note_cost(note1, note2):\n", + " \"\"\"\n", + " Cost of aligning (replacing) one note with another, based on pitch.\n", + " \n", + " cost = 0: pitches are identical (a 'match'). \n", + " cost > 0: different pitches (a 'replacement')\n", + " \n", + " Args:\n", + " note1: dict with keys \"pitch\" (int), \"start\" (float), \"duration\" (float)\n", + " note2: dict with keys \"pitch\" (int), \"start\" (float), \"duration\" (float)\n", + " \n", + " Returns:\n", + " int: cost value >= 0 (lower means more similar pitch)\n", + " \"\"\"\n", + " return int(abs(note1[\"pitch\"] - note2[\"pitch\"]))\n", + "\n", + "\n", + "def compute_event_cost(event1, event2, gap_penalty=DEFAULT_GAP_PENALTY):\n", + " \"\"\"\n", + " Cost of aligning (substituting) one event with another.\n", + " \n", + " Rules:\n", + " - note vs note: absolute pitch difference \n", + " - chord vs chord: Hamming distance on 12-dim pitch class binary vectors\n", + " - note vs chord (type mismatch): return gap_penalty so the aligner\n", + " treats them as unaligned (insertion + deletion is preferred)\n", + " \n", + " For chord vs chord, the Hamming distance counts how many of the 12\n", + " pitch classes differ between the two chords (symmetric difference size).\n", + " \n", + " Args:\n", + " event1: event dict (from group_notes_into_events)\n", + " event2: event dict (from group_notes_into_events)\n", + " gap_penalty: cost of an unaligned event; used as the type-mismatch cost\n", + " \n", + " Returns:\n", + " int: alignment cost >= 0\n", + " \"\"\"\n", + " type1 = event1[\"event_type\"]\n", + " type2 = event2[\"event_type\"]\n", + " \n", + " # Type mismatch: note vs chord or chord vs note.\n", + " # Return gap_penalty so alignment prefers to leave them unmatched.\n", + " if type1 != type2:\n", + " return gap_penalty\n", + " \n", + " # Both are single notes: use pitch difference (same as Phase 1)\n", + " elif type1 == \"note\" and type2 == \"note\":\n", + " return compute_note_cost(event1[\"notes\"][0], event2[\"notes\"][0])\n", + " \n", + " # Both are chords: Hamming distance on 12-dimensional pitch class vectors.\n", + " else:\n", + " vec1 = [0] * 12\n", + " vec2 = [0] * 12\n", + " for note in event1[\"notes\"]:\n", + " vec1[note[\"pitch\"] % 12] = 1\n", + " for note in event2[\"notes\"]:\n", + " vec2[note[\"pitch\"] % 12] = 1\n", + " \n", + " hamming = sum(1 for i in range(12) if vec1[i] != vec2[i])\n", + " return hamming" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "1f756e2f", + "metadata": {}, + "outputs": [], + "source": [ + "def event_alignment_ED(response_events, ref_events, gap_penalty=DEFAULT_GAP_PENALTY):\n", + " \"\"\"\n", + " Align events (notes or chords) using edit distance (ED). \n", + " The ED allows for insertions and deletions, which can be useful for \n", + " evaluating musical practice containing missing/extra notes.\n", + " \n", + " Args:\n", + " response_events: list of event dicts from group_notes_into_events\n", + " ref_events: list of event dicts from group_notes_into_events\n", + " gap_penalty: cost of leaving an event unaligned (insertion/deletion)\n", + " \n", + " Returns:\n", + " operations: list of transformation ops dicts, in order from first event to last:\n", + " {'type': 'match' or 'replacement' or 'missing' or 'extra', \n", + " 'response_idx': int or None, \n", + " 'reference_idx': int or None, \n", + " 'cost': int}\n", + " D: accumulated cost matrix, shape (N+1, M+1)\n", + " \"\"\"\n", + " # the rows of D correspond to response events\n", + " N = len(response_events)\n", + " # the columns of D correspond to reference events\n", + " M = len(ref_events)\n", + "\n", + " # Build the accumulated cost matrix D of size (N+1 x M+1)\n", + " D = np.zeros((N + 1, M + 1), dtype=int)\n", + " \n", + " # Boundary conditions: aligning against an empty sequence means every event\n", + " # is unaligned, so the cost is n (or m) times the gap penalty.\n", + " for n in range(1, N + 1):\n", + " D[n, 0] = n * gap_penalty # n extra response events\n", + " for m in range(1, M + 1):\n", + " D[0, m] = m * gap_penalty # m missing ref events\n", + "\n", + " # Recursion (accumulated cost / score matrix D):\n", + " for n in range(1, N + 1):\n", + " for m in range(1, M + 1):\n", + " replace_cost = compute_event_cost(response_events[n-1], ref_events[m-1], gap_penalty)\n", + " D[n, m] = min(\n", + " D[n-1, m-1] + replace_cost, # diagonal: match or replacement\n", + " D[n-1, m] + gap_penalty, # vertical: extra event response[n-1]\n", + " D[n, m-1] + gap_penalty, # horizontal: missing response for ref[m-1]\n", + " )\n", + "\n", + " # Backtrack and classify each transformation op based on movement direction in D\n", + " operations = []\n", + " n, m = N, M\n", + " while n > 0 or m > 0:\n", + " # Boundary conditions: at the top row, only horizontal moves possible\n", + " if n == 0:\n", + " # Missing response for ref[m-1] (deletion)\n", + " operations.append({\n", + " \"type\": \"missing\",\n", + " \"response_idx\": None,\n", + " \"reference_idx\": m - 1,\n", + " \"cost\": gap_penalty,\n", + " })\n", + " m -= 1\n", + " # At the leftmost column, only vertical moves possible\n", + " elif m == 0:\n", + " # Extra event response[n-1] (insertion)\n", + " operations.append({\n", + " \"type\": \"extra\",\n", + " \"response_idx\": n - 1,\n", + " \"reference_idx\": None,\n", + " \"cost\": gap_penalty,\n", + " })\n", + " n -= 1\n", + " # For all other cases, we can move in any direction (diagonal, vertical, horizontal)\n", + " else:\n", + " replace_cost = compute_event_cost(response_events[n - 1], ref_events[m - 1], gap_penalty)\n", + " diag = D[n - 1, m - 1] + replace_cost # diagonal: match or replacement\n", + " up = D[n - 1, m] + gap_penalty # vertical: extra event response[n-1]\n", + " left = D[n, m - 1] + gap_penalty # horizontal: missing response for ref[m-1]\n", + " min_cost = min(diag, up, left) # find the minimum cost step\n", + "\n", + " # classify the transformation ops based on the minimum cost step\n", + " if min_cost == diag: # Diagonal -> two events are aligned (match or replacement)\n", + " operations.append({\n", + " \"type\": \"match\" if replace_cost == 0 else \"replacement\",\n", + " \"response_idx\": n - 1,\n", + " \"reference_idx\": m - 1,\n", + " \"cost\": replace_cost,\n", + " })\n", + " n, m = n - 1, m - 1\n", + " elif min_cost == up: # Vertical -> response[n-1] is extra (insertion)\n", + " operations.append({\n", + " \"type\": \"extra\",\n", + " \"response_idx\": n - 1,\n", + " \"reference_idx\": None,\n", + " \"cost\": gap_penalty,\n", + " })\n", + " n -= 1\n", + " else: # Horizontal -> response is missing for ref[m-1] (deletion)\n", + " operations.append({\n", + " \"type\": \"missing\",\n", + " \"response_idx\": None,\n", + " \"reference_idx\": m - 1,\n", + " \"cost\": gap_penalty,\n", + " })\n", + " m -= 1\n", + "\n", + " operations.reverse() # Reverse to get ops in order from first note to last\n", + " return operations, D\n", + "\n", + "\n", + "def estimate_global_timing(operations, response_events, ref_events):\n", + " \"\"\"\n", + " Estimate the student's overall tempo relative to the reference, by fitting\n", + " a straight line through the matched note start times:\n", + " response_start ≈ scale * ref_start + offset\n", + " where:\n", + " scale > 1 means the student is playing slower overall\n", + " scale < 1 means the student is playing faster overall\n", + " offset captures any constant time shift\n", + "\n", + " Args:\n", + " operations: list of operation dicts (match/replacement/missing/extra)\n", + " response_events: list of event dicts from response\n", + " ref_events: list of event dicts from reference\n", + "\n", + " Returns:\n", + " scale: float, estimated tempo ratio (1.0 = same speed as reference)\n", + " offset: float (seconds), estimated constant time shift\n", + " \"\"\"\n", + " # Collect (ref_start, response_start) pairs from matched/replaced notes only.\n", + " # Missing/extra notes have no pair, so they cannot contribute to the fit.\n", + " ref_starts = []\n", + " response_starts = []\n", + " for op in operations:\n", + " if op[\"type\"] in (\"match\", \"replacement\"):\n", + " res = op[\"response_idx\"]\n", + " ref = op[\"reference_idx\"]\n", + " ref_starts.append(ref_events[ref][\"event_start\"])\n", + " response_starts.append(response_events[res][\"event_start\"])\n", + " \n", + " # Not enough points for fitting a meaningful line — assume no drift in tempo.\n", + " if len(ref_starts) < 3:\n", + " return 1.0, 0.0\n", + " \n", + " x = np.array(ref_starts, dtype=float)\n", + " y = np.array(response_starts, dtype=float)\n", + " \n", + " # Least-squares line fit: y = scale * x + offset\n", + " scale, offset = np.polyfit(x, y, 1)\n", + " \n", + " return float(scale), float(offset)\n", + "\n", + "\n", + "def estimate_global_duration_scale(operations, response_events, ref_events):\n", + " \"\"\"\n", + " Estimate the student's overall note-length scale relative to the reference,\n", + " by fitting a line through the origin:\n", + " response_duration ≈ duration_scale * ref_duration\n", + " where:\n", + " duration_scale > 1 means notes are held longer overall\n", + " duration_scale < 1 means notes are held shorter overall\n", + "\n", + " Args:\n", + " operations: output of note_alignment_ED()\n", + " response_events: list of student event dicts\n", + " ref_events: list of reference event dicts\n", + "\n", + " Returns:\n", + " duration_scale (float): estimated duration ratio (1.0 = same as reference)\n", + " \"\"\"\n", + " ref_durations = []\n", + " response_durations = []\n", + " for op in operations:\n", + " if op[\"type\"] in (\"match\", \"replacement\"):\n", + " res = op[\"response_idx\"]\n", + " ref = op[\"reference_idx\"]\n", + " ref_durations.append(ref_events[ref][\"event_duration\"])\n", + " response_durations.append(response_events[res][\"event_duration\"])\n", + "\n", + " if len(ref_durations) < 3:\n", + " return 1.0\n", + "\n", + " x = np.array(ref_durations, dtype=float)\n", + " y = np.array(response_durations, dtype=float)\n", + "\n", + " # Least-squares fit through the origin: y = scale * x\n", + " # Closed-form solution: scale = sum(x*y) / sum(x*x)\n", + " duration_scale = float(np.sum(x * y) / np.sum(x * x))\n", + "\n", + " return duration_scale" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "bdb915ea", + "metadata": {}, + "outputs": [], + "source": [ + "# Chord helper functions\n", + "def get_pitch_class_set(notes):\n", + " \"\"\"\n", + " Convert a list of notes to a set of pitch classes (each pitch mod 12).\n", + "\n", + " Args:\n", + " notes: list of note dicts, each with a \"pitch\" key.\n", + "\n", + " Returns:\n", + " set of ints, each in range [0, 11].\n", + " \"\"\"\n", + " return set(note[\"pitch\"] % 12 for note in notes)\n", + " \n", + " \n", + "def identify_chord_name(notes):\n", + " \"\"\"\n", + " Identify the chord name (e.g. \"C major\", \"A minor\") from a list of notes\n", + " by matching their pitch class set against CHORD_TEMPLATES.\n", + " \n", + " For each candidate root in the pitch class set, normalise all pitch classes\n", + " to start at 0 and check against each template. If no match is found,\n", + " returns \"unknown chord\".\n", + " \n", + " Args:\n", + " notes: list of note dicts, each with a \"pitch\" key.\n", + " \n", + " Returns:\n", + " str: chord name e.g. \"C major\", or \"unknown chord\".\n", + " \"\"\"\n", + " pitch_classes = get_pitch_class_set(notes)\n", + " \n", + " for root_pc in pitch_classes:\n", + " normalised = set((pc - root_pc) % 12 for pc in pitch_classes)\n", + " for chord_type, template in CHORD_TEMPLATES.items():\n", + " if normalised == template:\n", + " root_name = PITCH_CLASS_NAMES[root_pc]\n", + " return root_name + \" \" + chord_type\n", + " \n", + " return \"unknown chord\"\n", + " \n", + " \n", + "def compute_chord_accuracy(ref_notes, res_notes):\n", + " \"\"\"\n", + " Compute the chord accuracy score A from McLeod & Rohit (2022),\n", + " \n", + " A = (C - I + |y|) / (2 * |y|)\n", + " \n", + " where:\n", + " C = |y ∩ y_hat| (correctly played pitch classes)\n", + " I = |y_hat - y| (extra pitch classes played)\n", + " |y| (number of pitch classes in the reference chord)\n", + " \n", + " A = 1.0 means perfectly correct. A = 0.0 means nothing correct and\n", + " many extra notes are played.\n", + " \n", + " Args:\n", + " ref_notes: list of note dicts for the reference chord.\n", + " res_notes: list of note dicts for the response chord.\n", + " \n", + " Returns:\n", + " accuracy: float in [0, 1]\n", + " correct_pitches: sorted list of pitch class ints in both chords\n", + " missing_pitches: sorted list of pitch class ints in ref\n", + " extra_pitches:sorted list of pitch class ints in response\n", + " \"\"\"\n", + " ref_pcs = get_pitch_class_set(ref_notes)\n", + " res_pcs = get_pitch_class_set(res_notes)\n", + " \n", + " correct_pcs = ref_pcs & res_pcs\n", + " missing_pcs = ref_pcs - res_pcs\n", + " extra_pcs = res_pcs - ref_pcs\n", + " \n", + " C = len(correct_pcs)\n", + " I = len(extra_pcs)\n", + " ref_size = len(ref_pcs)\n", + " \n", + " if ref_size == 0:\n", + " accuracy = 0.0\n", + " else:\n", + " accuracy = (C - I + ref_size) / (2.0 * ref_size)\n", + " accuracy = max(0.0, min(1.0, accuracy))\n", + " \n", + " return accuracy, sorted(correct_pcs), sorted(missing_pcs), sorted(extra_pcs)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "88c61aa5", + "metadata": {}, + "outputs": [], + "source": [ + "def event_level_feedback(operations, response_events, ref_events,\n", + " timing_scale=1.0, timing_offset=0.0,\n", + " duration_scale=1.0,\n", + " timing_relative_threshold=TIMING_RELATIVE_THRESHOLD,\n", + " duration_relative_threshold=DURATION_RELATIVE_THRESHOLD):\n", + " \"\"\"\n", + " Analyse each aligned event pair (or missing/extra event) and return a\n", + " list of event result dicts.\n", + " For single-note events, pitch evaluation uses absolute pitch difference.\n", + " For chord events, pitch evaluation uses the chord accuracy metric A:\n", + " A = (C - I + |y|) / (2 * |y|)\n", + " where C = correctly played pitch classes, I = extra pitch classes played,\n", + " |y| = number of pitch classes in the reference chord.\n", + "\n", + " Args:\n", + " operations: list of op dicts (match/replacement/missing/extra)\n", + " response_events: list of event dicts from response\n", + " ref_events: list of event dicts from reference\n", + " timing_scale: float, estimated tempo ratio (1.0 = same speed as reference)\n", + " timing_offset: float (seconds), estimated constant time shift\n", + " duration_scale: float, estimated overall duration ratio\n", + " timing_relative_threshold: float, relative tolerance for timing correctness\n", + " duration_relative_threshold: float, relative tolerance for duration correctness\n", + "\n", + " Returns:\n", + " event_level_results : list of dicts, each dict contains:\n", + "\n", + " For note events, each dict has:\n", + " \"event_type\" -> \"note\"\n", + " \"reference_index\" -> int (1-based) or None\n", + " \"response_index\" -> int (1-based) or None\n", + " \"operation_type\" -> \"match\", \"replacement\", \"missing\", \"extra\"\n", + " \"pitch_correct\" -> bool\n", + " \"pitch_diff\" -> int (semitones) or None\n", + " \"timing_correct\" -> bool\n", + " \"timing_abs_diff\" -> float (seconds) or None\n", + " \"timing_relative_diff\" -> float or None\n", + " \"duration_correct\" -> bool\n", + " \"duration_abs_diff\" -> float (seconds) or None\n", + " \"duration_relative_diff\" -> float or None\n", + " \n", + " For chord events, each dict has:\n", + " \"event_type\" -> \"chord\"\n", + " \"reference_index\" -> int (1-based) or None\n", + " \"response_index\" -> int (1-based) or None\n", + " \"operation_type\" -> \"match\", \"replacement\", \"missing\", \"extra\"\n", + " \"chord_name_ref\" -> str e.g. \"C major\", or None\n", + " \"chord_name_res\" -> str e.g. \"C minor\", or None\n", + " \"chord_accuracy\" -> float (0 to 1) or None\n", + " \"correct_pitches\" -> list of pitch class ints or None\n", + " \"missing_pitches\" -> list of pitch class ints or None\n", + " \"extra_pitches\" -> list of pitch class ints or None\n", + " \"timing_correct\" -> bool\n", + " \"timing_abs_diff\" -> float (seconds) or None\n", + " \"timing_relative_diff\" -> float or None\n", + " \"duration_correct\" -> bool\n", + " \"duration_abs_diff\" -> float (seconds) or None\n", + " \"duration_relative_diff\" -> float or None\n", + " \"\"\"\n", + " # Compute IOI for each reference note: ioi[m] = ref_notes[m][\"start\"] - ref_notes[m-1][\"start\"]\n", + " # floor at 0.05s to avoid division by zero issues\n", + " ref_ioi = [None] * len(ref_events)\n", + " for m in range(1, len(ref_events)):\n", + " interval = ref_events[m][\"event_start\"] - ref_events[m - 1][\"event_start\"]\n", + " ref_ioi[m] = max(interval, 0.05)\n", + " \n", + " event_level_results = []\n", + "\n", + " for op in operations:\n", + " res_idx = op[\"response_idx\"]\n", + " ref_idx = op[\"reference_idx\"]\n", + " op_type = op[\"type\"]\n", + "\n", + " # Determine event type from whichever side is available\n", + " if ref_idx is not None:\n", + " event_type = ref_events[ref_idx][\"event_type\"]\n", + " else:\n", + " event_type = response_events[res_idx][\"event_type\"]\n", + "\n", + " # Missing/extra notes: no pitch/timing/duration comparison is possible,\n", + " # so all the numeric fields are set to None.\n", + " if op_type in (\"missing\", \"extra\"):\n", + " if event_type == \"note\":\n", + " event_level_results.append({\n", + " \"event_type\": \"note\",\n", + " \"reference_index\": (ref_idx + 1) if ref_idx is not None else None,\n", + " \"response_index\": (res_idx + 1) if res_idx is not None else None,\n", + " \"operation_type\": op_type,\n", + " \"pitch_correct\": False,\n", + " \"pitch_diff\": None,\n", + " \"timing_correct\": False,\n", + " \"timing_abs_diff\": None,\n", + " \"timing_relative_diff\": None,\n", + " \"duration_correct\": False,\n", + " \"duration_abs_diff\": None,\n", + " \"duration_relative_diff\": None,\n", + " })\n", + " else:\n", + " if ref_idx is not None:\n", + " chord_name = identify_chord_name(ref_events[ref_idx][\"notes\"])\n", + " else:\n", + " chord_name = identify_chord_name(\n", + " response_events[res_idx][\"notes\"]\n", + " )\n", + " event_level_results.append({\n", + " \"event_type\": \"chord\",\n", + " \"reference_index\": (ref_idx + 1) if ref_idx is not None else None,\n", + " \"response_index\": (res_idx + 1) if res_idx is not None else None,\n", + " \"operation_type\": op_type,\n", + " \"chord_name_ref\": chord_name if op_type == \"missing\" else None,\n", + " \"chord_name_res\": chord_name if op_type == \"extra\" else None,\n", + " \"chord_accuracy\": None,\n", + " \"correct_pitches\": None,\n", + " \"missing_pitches\": None,\n", + " \"extra_pitches\": None,\n", + " \"timing_correct\": False,\n", + " \"timing_abs_diff\": None,\n", + " \"timing_relative_diff\": None,\n", + " \"duration_correct\": False,\n", + " \"duration_abs_diff\": None,\n", + " \"duration_relative_diff\": None,\n", + " })\n", + " else:\n", + " # Aligned event pair (match or replacement)\n", + " res_event = response_events[res_idx]\n", + " ref_event = ref_events[ref_idx]\n", + "\n", + " # Timing — residual after removing the global tempo trend\n", + " predicted_start = timing_scale * ref_event[\"event_start\"] + timing_offset\n", + " timing_abs_diff = abs(res_event[\"event_start\"] - predicted_start)\n", + " if ref_idx == 0:\n", + " # First note will start at 0, so no difference.\n", + " timing_relative_diff = None\n", + " timing_correct = True\n", + " else:\n", + " ioi = ref_ioi[ref_idx]\n", + " timing_relative_diff = timing_abs_diff / ioi\n", + " timing_correct = (timing_relative_diff <= timing_relative_threshold)\n", + "\n", + " # Duration — residual after removing the global duration-scale trend\n", + " predicted_duration = duration_scale * ref_event[\"event_duration\"]\n", + " duration_abs_diff = res_event[\"event_duration\"] - predicted_duration\n", + " ref_dur = max(ref_event[\"event_duration\"], 0.05) # floor at 0.05s to avoid division by zero issues\n", + " duration_relative_diff = duration_abs_diff / ref_dur\n", + " duration_correct = (abs(duration_relative_diff) <= duration_relative_threshold)\n", + " if event_type == \"note\":\n", + " # For single-note events, pitch correctness is based on absolute pitch difference.\n", + " pitch1 = res_event[\"notes\"][0][\"pitch\"]\n", + " pitch2 = ref_event[\"notes\"][0][\"pitch\"]\n", + " pitch_diff = int(abs(pitch1 - pitch2))\n", + " event_level_results.append({\n", + " \"event_type\": \"note\",\n", + " \"reference_index\": ref_idx + 1,\n", + " \"response_index\": res_idx + 1,\n", + " \"operation_type\": op_type,\n", + " \"pitch_correct\": (pitch_diff == 0),\n", + " \"pitch_diff\": pitch_diff,\n", + " \"timing_correct\": timing_correct,\n", + " \"timing_abs_diff\": timing_abs_diff,\n", + " \"timing_relative_diff\": timing_relative_diff,\n", + " \"duration_correct\": duration_correct,\n", + " \"duration_abs_diff\": duration_abs_diff,\n", + " \"duration_relative_diff\": duration_relative_diff,\n", + " })\n", + " else:\n", + " # For chord events, pitch correctness is based on the chord accuracy metric.\n", + " accuracy, correct_pcs, missing_pcs, extra_pcs = (\n", + " compute_chord_accuracy(\n", + " ref_event[\"notes\"], res_event[\"notes\"]\n", + " )\n", + " )\n", + " event_level_results.append({\n", + " \"event_type\": \"chord\",\n", + " \"reference_index\": ref_idx + 1,\n", + " \"response_index\": res_idx + 1,\n", + " \"operation_type\": op_type,\n", + " \"chord_name_ref\": identify_chord_name(ref_event[\"notes\"]),\n", + " \"chord_name_res\": identify_chord_name(res_event[\"notes\"]),\n", + " \"chord_accuracy\": accuracy,\n", + " \"correct_pitches\": correct_pcs,\n", + " \"missing_pitches\": missing_pcs,\n", + " \"extra_pitches\": extra_pcs,\n", + " \"timing_correct\": timing_correct,\n", + " \"timing_abs_diff\": timing_abs_diff,\n", + " \"timing_relative_diff\": timing_relative_diff,\n", + " \"duration_correct\": duration_correct,\n", + " \"duration_abs_diff\": duration_abs_diff,\n", + " \"duration_relative_diff\": duration_relative_diff,\n", + " })\n", + "\n", + " return event_level_results\n", + "\n", + "\n", + "\n", + "def compute_stats(event_level_results, ref_events, timing_scale=1.0,\n", + " timing_offset=0.0, duration_scale=1.0):\n", + " \"\"\"\n", + " Compute summary counts and correctness booleans from event-level feedback.\n", + "\n", + " Args:\n", + " event_level_results: list of dicts, output of event_level_feedback()\n", + " ref_events: list of reference event dicts\n", + " timing_scale: float, from estimate_global_timing()\n", + " timing_offset: float, from estimate_global_timing()\n", + " duration_scale: float, from estimate_global_duration_scale()\n", + "\n", + " Returns:\n", + " stats: dict with keys:\n", + " \"pitch_all_correct\" -> bool, True if all note pitches are\n", + " correct AND all chord accuracies are 1.0\n", + " \"timing_all_correct\" -> bool\n", + " \"duration_all_correct\" -> bool\n", + " \"total_notes_in_reference\" -> int\n", + " \"total_notes_missing\" -> int\n", + " \"total_notes_extra\" -> int\n", + " \"total_notes_wrong_pitch\" -> int\n", + " \"total_notes_wrong_timing\" -> int\n", + " \"total_notes_wrong_duration\" -> int\n", + " \"total_notes_correct\" -> int\n", + " \"timing_scale\" -> float\n", + " \"timing_offset\" -> float\n", + " \"duration_scale\" -> float\n", + " \"total_chords_in_reference\" -> int\n", + " \"total_chords_missing\" -> int\n", + " \"total_chords_extra\" -> int\n", + " \"total_chords_correct\" -> int (accuracy == 1.0)\n", + " \"total_chords_imperfect\" -> int (0.0 < accuracy < 1.0)\n", + " \"total_chords_wrong\" -> int (accuracy == 0.0)\n", + " \"\"\"\n", + " note_events = [n for n in event_level_results if n[\"event_type\"] == \"note\"]\n", + " chord_events = [ch for ch in event_level_results if ch[\"event_type\"] == \"chord\"]\n", + " \n", + " ref_note_count = sum(1 for n in ref_events if n[\"event_type\"] == \"note\")\n", + " ref_chord_count = sum(1 for ch in ref_events if ch[\"event_type\"] == \"chord\")\n", + " \n", + " paired_notes = [\n", + " n for n in note_events\n", + " if n[\"operation_type\"] in (\"match\", \"replacement\")\n", + " ]\n", + " paired_chords = [\n", + " ch for ch in chord_events\n", + " if ch[\"operation_type\"] in (\"match\", \"replacement\")\n", + " ]\n", + " all_paired = [\n", + " e for e in event_level_results\n", + " if e[\"operation_type\"] in (\"match\", \"replacement\")\n", + " ]\n", + "\n", + " stats = {\n", + " \"pitch_all_correct\": (\n", + " all(n[\"pitch_correct\"] for n in paired_notes)\n", + " and all(\n", + " ch[\"chord_accuracy\"] == 1.0\n", + " for ch in paired_chords\n", + " if ch[\"chord_accuracy\"] is not None\n", + " )\n", + " ),\n", + " \"timing_all_correct\": all(n[\"timing_correct\"] for n in all_paired),\n", + " \"duration_all_correct\": all(n[\"duration_correct\"] for n in all_paired),\n", + " \"total_notes_in_reference\": ref_note_count,\n", + " \"total_notes_missing\": sum(1 for n in note_events if n[\"operation_type\"] == \"missing\"),\n", + " \"total_notes_extra\": sum(1 for n in note_events if n[\"operation_type\"] == \"extra\"),\n", + " \"total_notes_wrong_pitch\": sum(1 for n in paired_notes if not n[\"pitch_correct\"]),\n", + " \"total_notes_wrong_timing\": sum(1 for n in paired_notes if not n[\"timing_correct\"]),\n", + " \"total_notes_wrong_duration\": sum(1 for n in paired_notes if not n[\"duration_correct\"]),\n", + " \"total_notes_correct\": sum(1 for n in paired_notes\n", + " if n[\"pitch_correct\"] and n[\"timing_correct\"] and n[\"duration_correct\"]\n", + " ),\n", + " \"timing_scale\": timing_scale,\n", + " \"timing_offset\": timing_offset,\n", + " \"duration_scale\": duration_scale,\n", + " \"total_chords_in_reference\": ref_chord_count,\n", + " \"total_chords_missing\": sum(1 for ch in chord_events if ch[\"operation_type\"] == \"missing\"),\n", + " \"total_chords_extra\": sum(1 for ch in chord_events if ch[\"operation_type\"] == \"extra\"),\n", + " \"total_chords_correct\": sum(\n", + " 1 for ch in paired_chords\n", + " if ch[\"chord_accuracy\"] is not None and ch[\"chord_accuracy\"] == 1.0\n", + " ),\n", + " \"total_chords_imperfect\": sum(\n", + " 1 for ch in paired_chords\n", + " if ch[\"chord_accuracy\"] is not None\n", + " and ch[\"chord_accuracy\"] > 0.0\n", + " and ch[\"chord_accuracy\"] < 1.0\n", + " ),\n", + " \"total_chords_wrong\": sum(\n", + " 1 for ch in paired_chords\n", + " if ch[\"chord_accuracy\"] is not None and ch[\"chord_accuracy\"] == 0.0\n", + " ),\n", + " }\n", + " return stats\n", + "\n", + "\n", + "def generate_feedback_message(event_details, response_events, ref_events, stats,\n", + " global_slow_threshold=GLOBAL_SLOW_THRESHOLD,\n", + " global_fast_threshold=GLOBAL_FAST_THRESHOLD):\n", + " \"\"\"\n", + " Generate human-readable feedback messages for the student.\n", + "\n", + " Part 1 - Overview: summary of timing trend, duration trend, and total counts\n", + " of each error type (pitch / missing / extra).\n", + " Part 2 - Note Detail: pitch, timing, duration errors per note\n", + " Part 3 - Chord Detail: errors per chord\n", + "\n", + " Args:\n", + " event_details: list of dicts, output of event_level_feedback()\n", + " response_events: list of event dicts from group_notes_into_events\n", + " ref_events: list of event dicts from group_notes_into_events\n", + " stats: dict, output of compute_stats()\n", + " global_slow_threshold: timing_scale above this triggers \"too slow\" message\n", + " global_fast_threshold: timing_scale below this triggers \"too fast\" message\n", + "\n", + " Returns:\n", + " feedback_message (str)\n", + " \"\"\"\n", + " note_events = [n for n in event_details if n[\"event_type\"] == \"note\"]\n", + " chord_events = [ch for ch in event_details if ch[\"event_type\"] == \"chord\"]\n", + "\n", + " paired_notes = [\n", + " n for n in note_events\n", + " if n[\"operation_type\"] in (\"match\", \"replacement\")\n", + " ]\n", + " paired_chords = [\n", + " ch for ch in chord_events\n", + " if ch[\"operation_type\"] in (\"match\", \"replacement\")\n", + " ]\n", + "\n", + " timing_scale = stats[\"timing_scale\"]\n", + " timing_offset = stats[\"timing_offset\"]\n", + " duration_scale = stats[\"duration_scale\"]\n", + "\n", + " overview_messages = []\n", + " note_detail_messages = []\n", + " chord_detail_messages = []\n", + "\n", + " # ---------- Part 1: Overview ----------\n", + " # Tempo: acceptable / too slow / too fast ---\n", + " timing_pct = abs(timing_scale - 1.0) * 100\n", + " duration_pct = abs(duration_scale - 1.0) * 100\n", + " timing_direction = \"behind\" if timing_scale > 1.0 else \"ahead of\"\n", + " duration_direction = \"longer\" if duration_scale > 1.0 else \"shorter\"\n", + "\n", + " if timing_scale > global_slow_threshold:\n", + " overview_messages.append(\n", + " f\"Overall, your tempo is slower than the reference \"\n", + " f\"(timing is about {timing_pct:.0f}% {timing_direction} the reference in general while \"\n", + " f\"notes are held about {duration_pct:.0f}% {duration_direction} than the reference). \"\n", + " f\"No worries! You will get better when you practice more to get more familiar with it!\"\n", + " )\n", + " elif timing_scale < global_fast_threshold:\n", + " overview_messages.append(\n", + " f\"Overall, your tempo is faster than the reference \"\n", + " f\"(timing is about {timing_pct:.0f}% {timing_direction} the reference in general while \"\n", + " f\"notes are held about {duration_pct:.0f}% {duration_direction} than the reference). \"\n", + " f\"Don't rush even if you are confident in your performance.\" \n", + " f\"Slow down and give each note its full value.\"\n", + " )\n", + " else:\n", + " overview_messages.append(\n", + " f\"Timing: your overall tempo is within an acceptable range. Good job! \"\n", + " f\"The timing is about {timing_pct:.0f}% {timing_direction} the reference in general while \"\n", + " f\"notes are held about {duration_pct:.0f}% {duration_direction} than the reference.\"\n", + " )\n", + "\n", + " # Wrong pitch notes counts\n", + " if stats[\"total_notes_wrong_pitch\"] > 0:\n", + " s = \"is\" if stats[\"total_notes_wrong_pitch\"] == 1 else \"are\"\n", + " note_word = \"note\" if stats[\"total_notes_wrong_pitch\"] == 1 else \"notes\"\n", + " overview_messages.append(\n", + " f\"There {s} {stats['total_notes_wrong_pitch']} {note_word} played with the wrong pitch.\"\n", + " )\n", + " else:\n", + " overview_messages.append(\"There are no pitch errors. Well done!\")\n", + " # Missing notes counts\n", + " if stats[\"total_notes_missing\"] > 0:\n", + " s = \"is\" if stats[\"total_notes_missing\"] == 1 else \"are\"\n", + " note_word = \"note\" if stats[\"total_notes_missing\"] == 1 else \"notes\"\n", + " overview_messages.append(\n", + " f\"There {s} {stats['total_notes_missing']} {note_word} you missed from the reference.\"\n", + " )\n", + " else:\n", + " overview_messages.append(\"There are no missing notes. Great!\")\n", + " # Extra notes counts\n", + " if stats[\"total_notes_extra\"] > 0:\n", + " s = \"is\" if stats[\"total_notes_extra\"] == 1 else \"are\"\n", + " note_word = \"note\" if stats[\"total_notes_extra\"] == 1 else \"notes\"\n", + " overview_messages.append(\n", + " f\"There {s} {stats['total_notes_extra']} extra {note_word} played during practice. \"\n", + " f\"You may need to adjust your fingering or hand position to avoid extra notes.\"\n", + " )\n", + " else:\n", + " overview_messages.append(\"There are no extra notes. Good job!\")\n", + " # Chord errors counts\n", + " if stats[\"total_chords_in_reference\"] > 0:\n", + " total = stats[\"total_chords_in_reference\"]\n", + " correct = stats[\"total_chords_correct\"]\n", + " imperfect = stats[\"total_chords_imperfect\"]\n", + " wrong = stats[\"total_chords_wrong\"]\n", + " overview_messages.append(\n", + " f\"Chords: {correct}/{total} correct, \"\n", + " f\"{imperfect}/{total} imperfect (some notes missing or extra), \"\n", + " f\"{wrong}/{total} completely wrong.\"\n", + " )\n", + " if stats[\"total_chords_missing\"] > 0:\n", + " c_word = \"chord\" if stats[\"total_chords_missing\"] == 1 else \"chords\"\n", + " overview_messages.append(\n", + " f\"{stats['total_chords_missing']} {c_word} missed.\"\n", + " )\n", + " if stats[\"total_chords_extra\"] > 0:\n", + " c_word = \"chord\" if stats[\"total_chords_extra\"] == 1 else \"chords\"\n", + " overview_messages.append(\n", + " f\"{stats['total_chords_extra']} extra {c_word} played.\"\n", + " )\n", + "\n", + " # ---------- Part 2: Note Detail ----------\n", + " # Missing / extra notes\n", + " for n in note_events:\n", + " if n[\"operation_type\"] == \"missing\":\n", + " ref_zero_based = n[\"reference_index\"] - 1\n", + " pitch = ref_events[ref_zero_based][\"notes\"][0][\"pitch\"]\n", + " note_detail_messages.append(\n", + " f\"Note {n['reference_index']} (pitch {pitch}) is missing in your performance.\"\n", + " )\n", + " elif n[\"operation_type\"] == \"extra\":\n", + " res_zero_based = n[\"response_index\"] - 1\n", + " extra = response_events[res_zero_based][\"notes\"][0][\"pitch\"]\n", + " note_detail_messages.append(\n", + " f\"Extra note played: pitch {extra['pitch']} \"\n", + " f\"at t={response_events[res_zero_based]['event_start']:.2f}s \")\n", + " # Pitch errors\n", + " for n in paired_notes:\n", + " if not n[\"pitch_correct\"]:\n", + " ref_zero_based = n[\"reference_index\"] - 1\n", + " res_zero_based = n[\"response_index\"] - 1\n", + " ref_p = ref_events[ref_zero_based][\"notes\"][0][\"pitch\"]\n", + " res_p = response_events[res_zero_based][\"notes\"][0][\"pitch\"]\n", + " note_detail_messages.append(\n", + " f\"Note {n['reference_index']}: wrong pitch — \"\n", + " f\"expected {ref_p}, played {res_p} \"\n", + " f\"({n['pitch_diff']} semitone(s) off).\"\n", + " )\n", + " # Local timing errors - these are residuals after removing the global timing trend\n", + " for n in paired_notes:\n", + " if not n[\"timing_correct\"]:\n", + " note_detail_messages.append(\n", + " f\"Note {n['reference_index']}: timing is off by \"\n", + " f\"{n['timing_abs_diff']:.2f}s \"\n", + " f\"({n['timing_relative_diff'] * 100:.0f}% of the expected note interval), \"\n", + " f\"after accounting for the overall tempo trend.\"\n", + " )\n", + " # Local duration errors — these are residuals after removing the global duration trend\n", + " for n in paired_notes:\n", + " if not n[\"duration_correct\"]:\n", + " direction = \"longer\" if n[\"duration_abs_diff\"] > 0 else \"shorter\"\n", + " duration_pct_err = abs(n[\"duration_relative_diff\"]) * 100\n", + " note_detail_messages.append(\n", + " f\"Note {n['reference_index']}: duration is \"\n", + " f\"{abs(n['duration_abs_diff']):.2f}s {direction} than the reference \"\n", + " f\"({duration_pct_err:.0f}% off) after accounting for the overall duration trend.\"\n", + " )\n", + "\n", + " # ---------- Part 3: Chord Detail ----------\n", + " # Missing / extra chords\n", + " for ch in chord_events:\n", + " if ch[\"operation_type\"] == \"missing\":\n", + " chord_detail_messages.append(\n", + " f\"Chord {ch['reference_index']} ({ch['chord_name_ref']}) \"\n", + " f\"is missing in your performance.\"\n", + " )\n", + " elif ch[\"operation_type\"] == \"extra\":\n", + " chord_detail_messages.append(\n", + " f\"Extra chord played: {ch['chord_name_res']} \"\n", + " f\"at event position {ch['response_index']}.\"\n", + " )\n", + " # Chord accuracy errors\n", + " for ch in paired_chords:\n", + " if ch[\"chord_accuracy\"] is not None and ch[\"chord_accuracy\"] < 1.0:\n", + " accuracy_pct = round(ch[\"chord_accuracy\"] * 100)\n", + " message = (\n", + " f\"Chord {ch['reference_index']} \"\n", + " f\"(expected {ch['chord_name_ref']}, you played {ch['chord_name_res']}): \"\n", + " f\"{accuracy_pct}% accurate. \"\n", + " )\n", + " if ch[\"missing_pitches\"]:\n", + " missing_names = [PITCH_CLASS_NAMES[pc] for pc in ch[\"missing_pitches\"]]\n", + " message = message + \"Missing note(s): \" + \", \".join(missing_names) + \". \"\n", + " if ch[\"extra_pitches\"]:\n", + " extra_names = [PITCH_CLASS_NAMES[pc] for pc in ch[\"extra_pitches\"]]\n", + " message = message + \"Extra note(s) played: \" + \", \".join(extra_names) + \".\"\n", + " chord_detail_messages.append(message)\n", + " # Local timing errors for chords\n", + " for ch in paired_chords:\n", + " if not ch[\"timing_correct\"] and ch[\"timing_relative_diff\"] is not None:\n", + " chord_detail_messages.append(\n", + " f\"Chord {ch['reference_index']}: timing is off by \"\n", + " f\"{ch['timing_abs_diff']:.2f}s \"\n", + " f\"({ch['timing_relative_diff'] * 100:.0f}% of the expected interval).\"\n", + " )\n", + "\n", + " all_messages = [\"Overview: \"] + overview_messages\n", + "\n", + " if note_detail_messages:\n", + " all_messages = all_messages + [\"\", \"Note Detail:\"] + note_detail_messages\n", + " else:\n", + " all_messages = all_messages + [\"\", \"Great performance! No further issues found.\"]\n", + " \n", + " if stats[\"total_chords_in_reference\"] > 0:\n", + " if chord_detail_messages:\n", + " all_messages = (\n", + " all_messages + [\"\", \"Chord Detail:\"] + chord_detail_messages\n", + " )\n", + " else:\n", + " all_messages = all_messages + [\"\", \"All chords played correctly!\"]\n", + " \n", + " return \"\\n\".join(all_messages)\n", + "\n", + "\n", + "# FeedbackResult class\n", + "# ------------------------------------------------------------------------------\n", + "class FeedbackResult:\n", + " \"\"\"\n", + " Container for all outputs of compare_performance_ED().\n", + " Using a class (instead of returning a tuple) makes unit tests much clearer:\n", + " result = compare_performance_ED(response, reference)\n", + " assert result.is_correct == False\n", + " assert result.stats[\"total_notes_missing\"] == 1\n", + " assert \"missing\" in result.feedback_message\n", + "\n", + " Attributes\n", + " ----------\n", + " is_correct : bool\n", + " True only if all notes and chords are perfectly correct.\n", + " stats : dict\n", + " Aggregate counts. See compute_stats() for all keys.\n", + " event_details : list of dicts\n", + " Per-event analysis, one dict per alignment operation.\n", + " Each dict has an \"event_type\" key (\"note\" or \"chord\"),\n", + " plus type-specific fields. See event_level_feedback() for details.\n", + " feedback_message : str\n", + " Human-readable feedback string, ready to display to the student.\n", + " operations : list of dicts\n", + " Raw alignment operations from event_alignment_ED().\n", + " Kept here so visualisation helpers (plot_cost_matrix etc.) can use them.\n", + " D : numpy.ndarray\n", + " Accumulated cost matrix from the alignment step.\n", + " \"\"\"\n", + "\n", + " def __init__(self, is_correct, stats, event_details,\n", + " feedback_message, operations, D):\n", + " self.is_correct = is_correct\n", + " self.stats = stats\n", + " self.event_details = event_details\n", + " self.feedback_message = feedback_message\n", + " self.operations = operations\n", + " self.D = D\n", + " \n", + " def __repr__(self):\n", + " return (\n", + " \"FeedbackResult(is_correct=\" + str(self.is_correct) + \", \"\n", + " \"stats=\" + str(self.stats) + \")\"\n", + " )\n", "\n", "\n", - "def compute_event_cost(event_a, event_b):\n", - " pass" + "def compare_performance_ED(responseMIDI, refMIDI,\n", + " gap_penalty=DEFAULT_GAP_PENALTY,\n", + " timing_relative_threshold=TIMING_RELATIVE_THRESHOLD,\n", + " duration_relative_threshold=DURATION_RELATIVE_THRESHOLD,\n", + " global_slow_threshold=GLOBAL_SLOW_THRESHOLD,\n", + " global_fast_threshold=GLOBAL_FAST_THRESHOLD,\n", + " chord_onset_window=DEFAULT_CHORD_ONSET_WINDOW):\n", + " \"\"\"\n", + " Full pipeline: normalisation -> grouping -> alignment -> global trends\n", + " -> event-level evaluation -> summary statistics -> feedback.\n", + " \n", + " Args:\n", + " responseMIDI: student MIDI dict with key \"notes\"\n", + " refMIDI: reference MIDI dict with key \"notes\"\n", + " gap_penalty: cost of an unaligned event\n", + " timing_relative_threshold: see event_level_feedback()\n", + " duration_relative_threshold: see event_level_feedback()\n", + " global_slow_threshold: see generate_feedback_message()\n", + " global_fast_threshold: see generate_feedback_message()\n", + " chord_onset_window: float (seconds), notes within this window are\n", + " grouped into a chord. Default 0.050 (50ms). Teacher-configurable.\n", + " \n", + " Returns:\n", + " FeedbackResult object containing all analysis results\n", + " \"\"\"\n", + " # Step 0: Normalise start times\n", + " response_notes = normalize_start_times(responseMIDI[\"notes\"])\n", + " ref_notes = normalize_start_times(refMIDI[\"notes\"])\n", + " # Group notes into events (single notes or chords)\n", + " response_events = group_notes_into_events(response_notes, chord_onset_window)\n", + " ref_events = group_notes_into_events(ref_notes, chord_onset_window)\n", + "\n", + " # Step 1: Align events using edit distance\n", + " operations, D = event_alignment_ED(response_events, ref_events, gap_penalty)\n", + "\n", + " # Step 2: Estimate the overall tempo trend\n", + " timing_scale, timing_offset = estimate_global_timing(\n", + " operations, response_events, ref_events\n", + " )\n", + " duration_scale = estimate_global_duration_scale(\n", + " operations, response_events, ref_events\n", + " )\n", + "\n", + " # Step 3: Event-level evaluation\n", + " event_details = event_level_feedback(\n", + " operations, response_events, ref_events,\n", + " timing_scale=timing_scale,\n", + " timing_offset=timing_offset,\n", + " duration_scale=duration_scale,\n", + " timing_relative_threshold=timing_relative_threshold,\n", + " duration_relative_threshold=duration_relative_threshold,\n", + " )\n", + "\n", + " # Step 4: Compute summary statistics\n", + " stats = compute_stats(\n", + " event_details, ref_events,\n", + " timing_scale=timing_scale,\n", + " timing_offset=timing_offset,\n", + " duration_scale=duration_scale,\n", + " )\n", + "\n", + " # Step 5: Generate human-readable feedback\n", + " feedback_message = generate_feedback_message(\n", + " event_details, response_events, ref_events, stats,\n", + " global_slow_threshold=global_slow_threshold,\n", + " global_fast_threshold=global_fast_threshold,\n", + " )\n", + "\n", + " # Step 6: Overall pass/fail judgement\n", + " is_correct = (\n", + " stats[\"total_notes_missing\"] == 0\n", + " and stats[\"total_notes_extra\"] == 0\n", + " and stats[\"total_chords_missing\"] == 0\n", + " and stats[\"total_chords_extra\"] == 0\n", + " and stats[\"pitch_all_correct\"]\n", + " and stats[\"timing_all_correct\"]\n", + " and stats[\"duration_all_correct\"]\n", + " )\n", + "\n", + " return FeedbackResult(\n", + " is_correct=is_correct,\n", + " stats=stats,\n", + " event_details=event_details,\n", + " feedback_message=feedback_message,\n", + " operations=operations,\n", + " D=D,\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "5317bdff", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Overview: \n", + "Timing: your overall tempo is within an acceptable range. Good job! The timing is about 1% behind the reference in general while notes are held about 0% shorter than the reference.\n", + "There are no pitch errors. Well done!\n", + "There are no missing notes. Great!\n", + "There are no extra notes. Good job!\n", + "Chords: 1/3 correct, 1/3 imperfect (some notes missing or extra), 1/3 completely wrong.\n", + "\n", + "Great performance! No further issues found.\n", + "\n", + "Chord Detail:\n", + "Chord 2 (expected F major, you played unknown chord): 83% accurate. Missing note(s): F. \n", + "Chord 3 (expected G major, you played A minor): 0% accurate. Missing note(s): D, G, B. Extra note(s) played: C, E, A.\n" + ] + } + ], + "source": [ + "result = compare_performance_ED(res_chords_data, ref_chords_data)\n", + "print(result.feedback_message)" ] } ], From 2368c9b7def4c009c0d2fcb122aff1b3d2bb1ab4 Mon Sep 17 00:00:00 2001 From: ada-3e212e610b Date: Fri, 26 Jun 2026 22:19:22 +0100 Subject: [PATCH 3/8] add some test cases to check output messages --- notebooks/Chord_comparison.ipynb | 165 ++++++++++++++++++++++++++++--- 1 file changed, 149 insertions(+), 16 deletions(-) diff --git a/notebooks/Chord_comparison.ipynb b/notebooks/Chord_comparison.ipynb index 35ceae3..a6c6a27 100644 --- a/notebooks/Chord_comparison.ipynb +++ b/notebooks/Chord_comparison.ipynb @@ -23,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 24, "id": "db6e1544", "metadata": {}, "outputs": [], @@ -58,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 25, "id": "c04a662c", "metadata": {}, "outputs": [], @@ -117,7 +117,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "id": "fc06c1bb", "metadata": {}, "outputs": [], @@ -126,14 +126,14 @@ "DEFAULT_CHORD_ONSET_WINDOW = 0.05\n", "\n", "# Chord template dictionary.\n", - "# Each entry maps a chord name to a frozenset of pitch class intervals,\n", + "# Each entry maps a chord quality name to a set of pitch class intervals,\n", "# where the root note is normalised to 0.\n", "# Only the 4 triad types are included, following Muller (2021) Chapter 5.\n", "CHORD_TEMPLATES = {\n", - " \"major\": frozenset([0, 4, 7]),\n", - " \"minor\": frozenset([0, 3, 7]),\n", - " \"diminished\": frozenset([0, 3, 6]),\n", - " \"augmented\": frozenset([0, 4, 8]),\n", + " \"major\": set([0, 4, 7]),\n", + " \"minor\": set([0, 3, 7]),\n", + " \"diminished\": set([0, 3, 6]),\n", + " \"augmented\": set([0, 4, 8]),\n", "}\n", " \n", "# Pitch class names for human-readable feedback messages.\n", @@ -222,7 +222,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 27, "id": "6dbac7c6", "metadata": {}, "outputs": [], @@ -292,7 +292,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 28, "id": "1f756e2f", "metadata": {}, "outputs": [], @@ -487,7 +487,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 29, "id": "bdb915ea", "metadata": {}, "outputs": [], @@ -579,7 +579,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 30, "id": "88c61aa5", "metadata": {}, "outputs": [], @@ -1087,7 +1087,7 @@ " if note_detail_messages:\n", " all_messages = all_messages + [\"\", \"Note Detail:\"] + note_detail_messages\n", " else:\n", - " all_messages = all_messages + [\"\", \"Great performance! No further issues found.\"]\n", + " all_messages = all_messages + [\"\", \"All notes(melody part) played correctly!\"]\n", " \n", " if stats[\"total_chords_in_reference\"] > 0:\n", " if chord_detail_messages:\n", @@ -1095,7 +1095,7 @@ " all_messages + [\"\", \"Chord Detail:\"] + chord_detail_messages\n", " )\n", " else:\n", - " all_messages = all_messages + [\"\", \"All chords played correctly!\"]\n", + " all_messages = all_messages + [\"\", \"Great performance! No further issues on chords found.\"]\n", " \n", " return \"\\n\".join(all_messages)\n", "\n", @@ -1237,7 +1237,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 31, "id": "5317bdff", "metadata": {}, "outputs": [ @@ -1252,7 +1252,7 @@ "There are no extra notes. Good job!\n", "Chords: 1/3 correct, 1/3 imperfect (some notes missing or extra), 1/3 completely wrong.\n", "\n", - "Great performance! No further issues found.\n", + "All notes(melody part) played correctly!\n", "\n", "Chord Detail:\n", "Chord 2 (expected F major, you played unknown chord): 83% accurate. Missing note(s): F. \n", @@ -1264,6 +1264,139 @@ "result = compare_performance_ED(res_chords_data, ref_chords_data)\n", "print(result.feedback_message)" ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "7b90ed72", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded 11 test cases\n", + " - perfect_melody (short)\n", + " - perfect_chord (short)\n", + " - imperfect_chord_missing_one_note (short)\n", + " - wrong_chord_no_overlap (short)\n", + " - chord_with_extra_note (short)\n", + " - missing_chord (short)\n", + " - mixed_melody_and_chord_perfect (short)\n", + " - mixed_melody_wrong_pitch_imperfect_chord (short)\n", + " - type_mismatch_note_vs_chord (short)\n", + " - chord_timing_off (short)\n", + " - multiple_chords_mixed_accuracy (long)\n" + ] + } + ], + "source": [ + "test_path = os.path.join(dir, \"data\", \"test_cases_chords.json\")\n", + "\n", + "with open(test_path) as f:\n", + " test_cases = json.load(f)\n", + "\n", + "print(f\"Loaded {len(test_cases)} test cases\")\n", + "for case in test_cases:\n", + " print(f\" - {case['case_id']} ({case['length_category']})\")" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "81edb080", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Purpose: Two chords: first is perfect C major, second is imperfect A minor (missing E). total_chords_correct = 1, total_chords_imperfect = 1.\n", + "\n", + "is_correct: False\n", + "stats: {'pitch_all_correct': False, 'timing_all_correct': True, 'duration_all_correct': True, 'total_notes_in_reference': 0, 'total_notes_missing': 0, 'total_notes_extra': 0, 'total_notes_wrong_pitch': 0, 'total_notes_wrong_timing': 0, 'total_notes_wrong_duration': 0, 'total_notes_correct': 0, 'timing_scale': 1.0, 'timing_offset': 0.0, 'duration_scale': 1.0, 'total_chords_in_reference': 2, 'total_chords_missing': 0, 'total_chords_extra': 0, 'total_chords_correct': 1, 'total_chords_imperfect': 1, 'total_chords_wrong': 0}\n", + "\n", + "Overview: \n", + "Timing: your overall tempo is within an acceptable range. Good job! The timing is about 0% ahead of the reference in general while notes are held about 0% shorter than the reference.\n", + "There are no pitch errors. Well done!\n", + "There are no missing notes. Great!\n", + "There are no extra notes. Good job!\n", + "Chords: 1/2 correct, 1/2 imperfect (some notes missing or extra), 0/2 completely wrong.\n", + "\n", + "All notes(melody part) played correctly!\n", + "\n", + "Chord Detail:\n", + "Chord 2 (expected A minor, you played unknown chord): 83% accurate. Missing note(s): E. \n" + ] + } + ], + "source": [ + "# Pick one case by id to inspect closely\n", + "case_id = \"multiple_chords_mixed_accuracy\"\n", + "case = next(c for c in test_cases if c[\"case_id\"] == case_id)\n", + "\n", + "print(\"Purpose:\", case[\"purpose\"])\n", + "print()\n", + "\n", + "result = compare_performance_ED(case[\"response\"], case[\"reference\"])\n", + "\n", + "print(\"is_correct:\", result.is_correct)\n", + "print(\"stats:\", result.stats)\n", + "print()\n", + "print(result.feedback_message)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "83a5b636", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[perfect_melody ] is_correct=True pitch_all=True timing_all=True duration_all=True notes_ref=4 missing=0 extra=0 wrong_pitch=0 wrong_timing=0 wrong_dur=0 notes_correct=4 chords_ref=0 chords_missing=0 chords_extra=0 chords_correct=0 chords_imperfect=0 chords_wrong=0 timing_scale=1.00 duration_scale=1.00\n", + "[perfect_chord ] is_correct=True pitch_all=True timing_all=True duration_all=True notes_ref=0 missing=0 extra=0 wrong_pitch=0 wrong_timing=0 wrong_dur=0 notes_correct=0 chords_ref=1 chords_missing=0 chords_extra=0 chords_correct=1 chords_imperfect=0 chords_wrong=0 timing_scale=1.00 duration_scale=1.00\n", + "[imperfect_chord_missing_one_note ] is_correct=False pitch_all=False timing_all=True duration_all=True notes_ref=0 missing=0 extra=0 wrong_pitch=0 wrong_timing=0 wrong_dur=0 notes_correct=0 chords_ref=1 chords_missing=0 chords_extra=0 chords_correct=0 chords_imperfect=1 chords_wrong=0 timing_scale=1.00 duration_scale=1.00\n", + "[wrong_chord_no_overlap ] is_correct=False pitch_all=False timing_all=True duration_all=True notes_ref=0 missing=0 extra=0 wrong_pitch=0 wrong_timing=0 wrong_dur=0 notes_correct=0 chords_ref=1 chords_missing=0 chords_extra=0 chords_correct=0 chords_imperfect=0 chords_wrong=1 timing_scale=1.00 duration_scale=1.00\n", + "[chord_with_extra_note ] is_correct=False pitch_all=False timing_all=True duration_all=True notes_ref=0 missing=0 extra=0 wrong_pitch=0 wrong_timing=0 wrong_dur=0 notes_correct=0 chords_ref=1 chords_missing=0 chords_extra=0 chords_correct=0 chords_imperfect=1 chords_wrong=0 timing_scale=1.00 duration_scale=1.00\n", + "[missing_chord ] is_correct=False pitch_all=True timing_all=True duration_all=True notes_ref=1 missing=0 extra=0 wrong_pitch=0 wrong_timing=0 wrong_dur=0 notes_correct=1 chords_ref=1 chords_missing=1 chords_extra=0 chords_correct=0 chords_imperfect=0 chords_wrong=0 timing_scale=1.00 duration_scale=1.00\n", + "[mixed_melody_and_chord_perfect ] is_correct=True pitch_all=True timing_all=True duration_all=True notes_ref=1 missing=0 extra=0 wrong_pitch=0 wrong_timing=0 wrong_dur=0 notes_correct=1 chords_ref=1 chords_missing=0 chords_extra=0 chords_correct=1 chords_imperfect=0 chords_wrong=0 timing_scale=1.00 duration_scale=1.00\n", + "[mixed_melody_wrong_pitch_imperfect_chord] is_correct=False pitch_all=False timing_all=True duration_all=True notes_ref=1 missing=0 extra=0 wrong_pitch=1 wrong_timing=0 wrong_dur=0 notes_correct=0 chords_ref=1 chords_missing=0 chords_extra=0 chords_correct=0 chords_imperfect=1 chords_wrong=0 timing_scale=1.00 duration_scale=1.00\n", + "[type_mismatch_note_vs_chord ] is_correct=False pitch_all=False timing_all=True duration_all=True notes_ref=1 missing=0 extra=0 wrong_pitch=0 wrong_timing=0 wrong_dur=0 notes_correct=1 chords_ref=1 chords_missing=0 chords_extra=0 chords_correct=0 chords_imperfect=1 chords_wrong=0 timing_scale=1.00 duration_scale=1.00\n", + "[chord_timing_off ] is_correct=False pitch_all=True timing_all=False duration_all=True notes_ref=1 missing=0 extra=0 wrong_pitch=0 wrong_timing=0 wrong_dur=0 notes_correct=1 chords_ref=1 chords_missing=0 chords_extra=0 chords_correct=1 chords_imperfect=0 chords_wrong=0 timing_scale=1.00 duration_scale=1.00\n", + "[multiple_chords_mixed_accuracy ] is_correct=False pitch_all=False timing_all=True duration_all=True notes_ref=0 missing=0 extra=0 wrong_pitch=0 wrong_timing=0 wrong_dur=0 notes_correct=0 chords_ref=2 chords_missing=0 chords_extra=0 chords_correct=1 chords_imperfect=1 chords_wrong=0 timing_scale=1.00 duration_scale=1.00\n" + ] + } + ], + "source": [ + "for case in test_cases:\n", + " result = compare_performance_ED(case[\"response\"], case[\"reference\"])\n", + " s = result.stats\n", + " print(\n", + " f\"[{case['case_id']:40s}] \"\n", + " f\"is_correct={result.is_correct!s:5} \"\n", + " f\"pitch_all={s['pitch_all_correct']!s:5} \"\n", + " f\"timing_all={s['timing_all_correct']!s:5} \"\n", + " f\"duration_all={s['duration_all_correct']!s:5} \"\n", + " f\"notes_ref={s['total_notes_in_reference']} \"\n", + " f\"missing={s['total_notes_missing']} \"\n", + " f\"extra={s['total_notes_extra']} \"\n", + " f\"wrong_pitch={s['total_notes_wrong_pitch']} \"\n", + " f\"wrong_timing={s['total_notes_wrong_timing']} \"\n", + " f\"wrong_dur={s['total_notes_wrong_duration']} \"\n", + " f\"notes_correct={s['total_notes_correct']} \"\n", + " f\"chords_ref={s['total_chords_in_reference']} \"\n", + " f\"chords_missing={s['total_chords_missing']} \"\n", + " f\"chords_extra={s['total_chords_extra']} \"\n", + " f\"chords_correct={s['total_chords_correct']} \"\n", + " f\"chords_imperfect={s['total_chords_imperfect']} \"\n", + " f\"chords_wrong={s['total_chords_wrong']} \"\n", + " f\"timing_scale={s['timing_scale']:.2f} \"\n", + " f\"duration_scale={s['duration_scale']:.2f}\"\n", + " )" + ] } ], "metadata": { From 3c39bd200bd0b6f3daec6b1895d91c3694f61ced Mon Sep 17 00:00:00 2001 From: ada-3e212e610b Date: Fri, 26 Jun 2026 22:28:21 +0100 Subject: [PATCH 4/8] fix a typeerror in generate feedback message --- notebooks/Chord_comparison.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/notebooks/Chord_comparison.ipynb b/notebooks/Chord_comparison.ipynb index a6c6a27..1c71ecf 100644 --- a/notebooks/Chord_comparison.ipynb +++ b/notebooks/Chord_comparison.ipynb @@ -579,7 +579,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "id": "88c61aa5", "metadata": {}, "outputs": [], @@ -1010,7 +1010,7 @@ " res_zero_based = n[\"response_index\"] - 1\n", " extra = response_events[res_zero_based][\"notes\"][0][\"pitch\"]\n", " note_detail_messages.append(\n", - " f\"Extra note played: pitch {extra['pitch']} \"\n", + " f\"Extra note played: pitch {extra} \"\n", " f\"at t={response_events[res_zero_based]['event_start']:.2f}s \")\n", " # Pitch errors\n", " for n in paired_notes:\n", From fed552246f5103f5fdf6c73e826e781babcce1dc Mon Sep 17 00:00:00 2001 From: ada-3e212e610b Date: Sat, 27 Jun 2026 15:59:44 +0100 Subject: [PATCH 5/8] fix some typo and package the code, tidy up the notebook --- evaluation_function/compare_MIDI.py | 854 ++++++++++++++----- notebooks/Chord_comparison.ipynb | 1199 +-------------------------- 2 files changed, 673 insertions(+), 1380 deletions(-) diff --git a/evaluation_function/compare_MIDI.py b/evaluation_function/compare_MIDI.py index 363b98e..bde6b46 100644 --- a/evaluation_function/compare_MIDI.py +++ b/evaluation_function/compare_MIDI.py @@ -5,12 +5,14 @@ Pipeline overview (called in order by compare_performance_ED): Step 0 -- normalize_start_times (make first note start at t = 0.0) - Step 1 -- note_alignment_ED (edit-distance alignment) + group_notes_into_events (group simultaneous notes into chords) + Step 1 -- event_alignment_ED (edit-distance alignment) Step 2 -- estimate_global_timing (linear regression for tempo drift) estimate_global_duration_scale - Step 3 -- note_level_feedback (per-note pitch / timing / duration check) + Step 3 -- event_level_feedback (note/chord-level feedback) Step 4 -- compute_stats (summary counts) Step 5 -- generate_feedback_message (human-readable text) + Step 6 -- is_correct (overall pass/fail judgement) """ @@ -34,6 +36,109 @@ GLOBAL_SLOW_THRESHOLD = 1.15 # timing_scale > 1.15 -> "overall too slow" GLOBAL_FAST_THRESHOLD = 0.85 # timing_scale < 0.85 -> "overall too fast" +# Default threshold: notes starting within 50ms are grouped as one chord. +DEFAULT_CHORD_ONSET_WINDOW = 0.05 + +# template and helper functions for chords +# ------------------------------------------------------------------------------ +# Chord template dictionary. +# Each entry maps a chord quality name to a set of pitch class intervals, +# where the root note is normalised to 0. +# Only the 4 triad types are included, following Muller (2021) Chapter 5. +CHORD_TEMPLATES = { + "major": set([0, 4, 7]), + "minor": set([0, 3, 7]), + "diminished": set([0, 3, 6]), + "augmented": set([0, 4, 8]), +} + +# Pitch class names for human-readable feedback messages. +PITCH_CLASS_NAMES = ["C", "C#", "D", "D#", "E", "F", + "F#", "G", "G#", "A", "A#", "B"] + +# Chord helper functions +def get_pitch_class_set(notes): + """ + Convert a list of notes to a set of pitch classes (each pitch mod 12). + + Args: + notes: list of note dicts, each with a "pitch" key. + + Returns: + set of ints, each in range [0, 11]. + """ + return set(note["pitch"] % 12 for note in notes) + +def identify_chord_name(notes): + """ + Identify the chord name (e.g. "C major", "A minor") from a list of notes + by matching their pitch class set against CHORD_TEMPLATES. + + For each candidate root in the pitch class set, normalise all pitch classes + to start at 0 and check against each template. If no match is found, + returns "unknown chord". + + Args: + notes: list of note dicts, each with a "pitch" key. + + Returns: + str: chord name e.g. "C major", or "unknown chord". + """ + pitch_classes = get_pitch_class_set(notes) + + for root_pc in pitch_classes: + normalised = set((pc - root_pc) % 12 for pc in pitch_classes) + for chord_type, template in CHORD_TEMPLATES.items(): + if normalised == template: + root_name = PITCH_CLASS_NAMES[root_pc] + return root_name + " " + chord_type + + return "unknown chord" + + +def compute_chord_accuracy(ref_notes, res_notes): + """ + Compute the chord accuracy score A from McLeod & Rohit (2022), + + A = (C - I + |y|) / (2 * |y|) + + where: + C = |y ∩ y_hat| (correctly played pitch classes) + I = |y_hat - y| (extra pitch classes played) + |y| (number of pitch classes in the reference chord) + + A = 1.0 means perfectly correct. A = 0.0 means nothing correct and + many extra notes are played. + + Args: + ref_notes: list of note dicts for the reference chord. + res_notes: list of note dicts for the response chord. + + Returns: + accuracy: float in [0, 1] + correct_pitches: sorted list of pitch class ints in both chords + missing_pitches: sorted list of pitch class ints in ref + extra_pitches:sorted list of pitch class ints in response + """ + ref_pcs = get_pitch_class_set(ref_notes) + res_pcs = get_pitch_class_set(res_notes) + + correct_pcs = ref_pcs & res_pcs + missing_pcs = ref_pcs - res_pcs + extra_pcs = res_pcs - ref_pcs + + C = len(correct_pcs) + I = len(extra_pcs) + ref_size = len(ref_pcs) + + if ref_size == 0: + accuracy = 0.0 + else: + accuracy = (C - I + ref_size) / (2.0 * ref_size) + accuracy = max(0.0, min(1.0, accuracy)) + + return accuracy, sorted(correct_pcs), sorted(missing_pcs), sorted(extra_pcs) + # Step 0 - make first note start at t = 0.0 # ------------------------------------------------------------------------------ @@ -66,10 +171,89 @@ def normalize_start_times(notes): return shifted_notes +# helper function to build an event dict from a group of notes +def make_event(notes_in_group): + """ + Build a single event dict from a group of notes. + + Args: + notes_in_group: list of one or more note dicts. + + Returns: + event dict with keys "event_type", "notes", "start". example: + { + "event_type": "note" or "chord" depending on the number of notes in the group + "notes": [ + { + "pitch": int + "start": float + "duration": float + }, + ] + "event_start": float, use the start time of the first note in the group + "event_duration": float, use the longest duration among all notes in the group + } + """ + event_type = "note" if len(notes_in_group) == 1 else "chord" + return { + "event_type": event_type, + "notes": notes_in_group, + # use the start time of the first note in the group as the event start + "event_start": notes_in_group[0]["start"], + # use the longest duration among all notes in the group as the event duration + "event_duration": max(note["duration"] for note in notes_in_group), + } + +# group notes into events based on their start times +def group_notes_into_events(notes, chord_onset_window=DEFAULT_CHORD_ONSET_WINDOW): + """ + Group a flat list of notes into events. Notes whose start times fall + within chord_onset_window seconds of each other are placed into the + same event (i.e. treated as a chord). Notes that are not grouped with + any other note form a single-note event. + + Args: + notes: list of dicts, each with keys 'pitch', 'start', 'duration' + chord_onset_window: float, max time difference (seconds) to be grouped + + Returns: + events: list of event dicts, each with keys: + "event_type": "note" if only one note, "chord" if two or more + "notes": list of note dicts belonging to this event + "event_start": float, start time of the first note in the group + "event_duration": float, longest duration among all notes in the group + """ + if len(notes) == 0: + return [] + + # Sort notes by start time first + sorted_notes = sorted(notes, key=lambda n: n["start"]) + + events = [] + current_group = [sorted_notes[0]] + group_start = sorted_notes[0]["start"] + + for note in sorted_notes[1:]: + # Close enough in time: add to current group (chord) + if note["start"] - group_start <= chord_onset_window: + current_group.append(note) + else: + # Too far apart: save current chord, start a new group + event = make_event(current_group) + events.append(event) + current_group = [note] + group_start = note["start"] + + # append the last group + last_event = make_event(current_group) + events.append(last_event) + + return events + # Step 1 -- edit-distance alignment to identify missing/extra notes and pitch errors # ------------------------------------------------------------------------------ -def compute_cost(note1, note2): +def compute_note_cost(note1, note2): """ Cost of aligning (replacing) one note with another, based on pitch. @@ -86,47 +270,93 @@ def compute_cost(note1, note2): return int(abs(note1["pitch"] - note2["pitch"])) -def note_alignment_ED(response_notes, ref_notes, gap_penalty=DEFAULT_GAP_PENALTY): +def compute_event_cost(event1, event2, gap_penalty=DEFAULT_GAP_PENALTY): """ - Align notes using edit distance (ED). + Cost of aligning (substituting) one event with another. + + Rules: + - note vs note: absolute pitch difference + - chord vs chord: Hamming distance on 12-dim pitch class binary vectors + - note vs chord (type mismatch): return gap_penalty so the aligner + treats them as unaligned (insertion + deletion is preferred) + + For chord vs chord, the Hamming distance counts how many of the 12 + pitch classes differ between the two chords (symmetric difference size). + + Args: + event1: event dict (from group_notes_into_events) + event2: event dict (from group_notes_into_events) + gap_penalty: cost of an unaligned event; used as the type-mismatch cost + + Returns: + int: alignment cost >= 0 + """ + type1 = event1["event_type"] + type2 = event2["event_type"] + + # Type mismatch: note vs chord or chord vs note. + # Return gap_penalty so alignment prefers to leave them unmatched. + if type1 != type2: + return gap_penalty + + # Both are single notes: use pitch difference (same as Phase 1) + elif type1 == "note" and type2 == "note": + return compute_note_cost(event1["notes"][0], event2["notes"][0]) + + # Both are chords: Hamming distance on 12-dimensional pitch class vectors. + # i.e. count how many of the 12 pitch classes differ between the two chords. + else: + vec1 = [0] * 12 + vec2 = [0] * 12 + for note in event1["notes"]: + vec1[note["pitch"] % 12] = 1 + for note in event2["notes"]: + vec2[note["pitch"] % 12] = 1 + hamming = sum(1 for i in range(12) if vec1[i] != vec2[i]) + return hamming + + +def event_alignment_ED(response_events, ref_events, gap_penalty=DEFAULT_GAP_PENALTY): + """ + Align events (notes or chords) using edit distance (ED). The ED allows for insertions and deletions, which can be useful for evaluating musical practice containing missing/extra notes. Args: - response_notes: The student's response MIDI notes to evaluate - ref_notes: The reference MIDI note - gap_penalty: cost of leaving a note unaligned (insertion/deletion) + response_events: list of event dicts from group_notes_into_events + ref_events: list of event dicts from group_notes_into_events + gap_penalty: cost of leaving an event unaligned (insertion/deletion) Returns: - operations: list of transformation ops dicts, in order from first note to last: + operations: list of transformation ops dicts, in order from first event to last: {'type': 'match' or 'replacement' or 'missing' or 'extra', 'response_idx': int or None, 'reference_idx': int or None, 'cost': int} D: accumulated cost matrix, shape (N+1, M+1) """ - # the rows of D correspond to response notes - N = len(response_notes) - # the columns of D correspond to reference notes - M = len(ref_notes) + # the rows of D correspond to response events + N = len(response_events) + # the columns of D correspond to reference events + M = len(ref_events) # Build the accumulated cost matrix D of size (N+1 x M+1) D = np.zeros((N + 1, M + 1), dtype=int) - # Boundary conditions: aligning against an empty sequence means every note + # Boundary conditions: aligning against an empty sequence means every event # is unaligned, so the cost is n (or m) times the gap penalty. for n in range(1, N + 1): - D[n, 0] = n * gap_penalty # n extra response notes + D[n, 0] = n * gap_penalty # n extra response events for m in range(1, M + 1): - D[0, m] = m * gap_penalty # m missing ref notes + D[0, m] = m * gap_penalty # m missing ref events # Recursion (accumulated cost / score matrix D): for n in range(1, N + 1): for m in range(1, M + 1): - replace_cost = compute_cost(response_notes[n-1], ref_notes[m-1]) + replace_cost = compute_event_cost(response_events[n-1], ref_events[m-1], gap_penalty) D[n, m] = min( D[n-1, m-1] + replace_cost, # diagonal: match or replacement - D[n-1, m] + gap_penalty, # vertical: extra note response[n-1] + D[n-1, m] + gap_penalty, # vertical: extra event response[n-1] D[n, m-1] + gap_penalty, # horizontal: missing response for ref[m-1] ) @@ -146,7 +376,7 @@ def note_alignment_ED(response_notes, ref_notes, gap_penalty=DEFAULT_GAP_PENALTY m -= 1 # At the leftmost column, only vertical moves possible elif m == 0: - # Extra note response[n-1] (insertion) + # Extra event response[n-1] (insertion) operations.append({ "type": "extra", "response_idx": n - 1, @@ -156,13 +386,15 @@ def note_alignment_ED(response_notes, ref_notes, gap_penalty=DEFAULT_GAP_PENALTY n -= 1 # For all other cases, we can move in any direction (diagonal, vertical, horizontal) else: - replace_cost = compute_cost(response_notes[n - 1], ref_notes[m - 1]) + replace_cost = compute_event_cost(response_events[n - 1], ref_events[m - 1], gap_penalty) diag = D[n - 1, m - 1] + replace_cost # diagonal: match or replacement - up = D[n - 1, m] + gap_penalty # vertical: extra note response[n-1] + up = D[n - 1, m] + gap_penalty # vertical: extra event response[n-1] left = D[n, m - 1] + gap_penalty # horizontal: missing response for ref[m-1] min_cost = min(diag, up, left) # find the minimum cost step + # classify the transformation ops based on the minimum cost step - if min_cost == diag: # Diagonal -> two notes are aligned (match or replacement) + # rule: always prefer diagonal > insertion or deletion + if min_cost == diag: # Diagonal -> two events are aligned (match or replacement) operations.append({ "type": "match" if replace_cost == 0 else "replacement", "response_idx": n - 1, @@ -193,7 +425,7 @@ def note_alignment_ED(response_notes, ref_notes, gap_penalty=DEFAULT_GAP_PENALTY # Step 2 -- estimate_global_timing and estimate_global_duration_scale # ------------------------------------------------------------------------------ -def estimate_global_timing(operations, response_notes, ref_notes): +def estimate_global_timing(operations, response_events, ref_events): """ Estimate the student's overall tempo relative to the reference, by fitting a straight line through the matched note start times: @@ -205,8 +437,8 @@ def estimate_global_timing(operations, response_notes, ref_notes): Args: operations: list of operation dicts (match/replacement/missing/extra) - response_notes: list of note dicts from response - ref_notes: list of note dicts from reference + response_events: list of event dicts from response + ref_events: list of event dicts from reference Returns: scale: float, estimated tempo ratio (1.0 = same speed as reference) @@ -218,8 +450,10 @@ def estimate_global_timing(operations, response_notes, ref_notes): response_starts = [] for op in operations: if op["type"] in ("match", "replacement"): - ref_starts.append(ref_notes[op["reference_idx"]]["start"]) - response_starts.append(response_notes[op["response_idx"]]["start"]) + res = op["response_idx"] + ref = op["reference_idx"] + ref_starts.append(ref_events[ref]["event_start"]) + response_starts.append(response_events[res]["event_start"]) # Not enough points for fitting a meaningful line — assume no drift in tempo. if len(ref_starts) < 3: @@ -233,7 +467,8 @@ def estimate_global_timing(operations, response_notes, ref_notes): return float(scale), float(offset) -def estimate_global_duration_scale(operations, response_notes, ref_notes): + +def estimate_global_duration_scale(operations, response_events, ref_events): """ Estimate the student's overall note-length scale relative to the reference, by fitting a line through the origin: @@ -243,9 +478,9 @@ def estimate_global_duration_scale(operations, response_notes, ref_notes): duration_scale < 1 means notes are held shorter overall Args: - operations: output of note_alignment_ED() - response_notes: list of student note dicts - ref_notes: list of reference note dicts + operations: output of event_alignment_ED() + response_events: list of student event dicts + ref_events: list of reference event dicts Returns: duration_scale (float): estimated duration ratio (1.0 = same as reference) @@ -254,8 +489,10 @@ def estimate_global_duration_scale(operations, response_notes, ref_notes): response_durations = [] for op in operations: if op["type"] in ("match", "replacement"): - ref_durations.append(ref_notes[op["reference_idx"]]["duration"]) - response_durations.append(response_notes[op["response_idx"]]["duration"]) + res = op["response_idx"] + ref = op["reference_idx"] + ref_durations.append(ref_events[ref]["event_duration"]) + response_durations.append(response_events[res]["event_duration"]) if len(ref_durations) < 3: return 1.0 @@ -270,20 +507,26 @@ def estimate_global_duration_scale(operations, response_notes, ref_notes): return duration_scale -# Step 3 -- note_level_feedback +# Step 3 -- event_level_feedback # ------------------------------------------------------------------------------ -def note_level_feedback(operations, response_notes, ref_notes, - timing_scale=1.0, timing_offset=0.0, duration_scale=1.0, - timing_relative_threshold=TIMING_RELATIVE_THRESHOLD, - duration_relative_threshold=DURATION_RELATIVE_THRESHOLD): +def event_level_feedback(operations, response_events, ref_events, + timing_scale=1.0, timing_offset=0.0, + duration_scale=1.0, + timing_relative_threshold=TIMING_RELATIVE_THRESHOLD, + duration_relative_threshold=DURATION_RELATIVE_THRESHOLD): """ - Analyse each aligned note pair (or missing/extra event) and return a list - of note result dicts. + Analyse each aligned event pair (or missing/extra event) and return a + list of event result dicts. + For single-note events, pitch evaluation uses absolute pitch difference. + For chord events, pitch evaluation uses the chord accuracy metric A: + A = (C - I + |y|) / (2 * |y|) + where C = correctly played pitch classes, I = extra pitch classes played, + |y| = number of pitch classes in the reference chord. Args: operations: list of op dicts (match/replacement/missing/extra) - response_notes: list of note dicts from response - ref_notes: list of note dicts from reference + response_events: list of event dicts from response + ref_events: list of event dicts from reference timing_scale: float, estimated tempo ratio (1.0 = same speed as reference) timing_offset: float (seconds), estimated constant time shift duration_scale: float, estimated overall duration ratio @@ -291,61 +534,111 @@ def note_level_feedback(operations, response_notes, ref_notes, duration_relative_threshold: float, relative tolerance for duration correctness Returns: - note_level_results : list of dicts, each dict contains: - "reference_index" -> int (1-based) or None if operation_type = extra - "response_index" -> int (1-based) or None if operation_type = missing - "operation_type" -> str: "match", "replacement", "missing", or "extra" + event_level_results : list of dicts, each dict contains: + + For note events, each dict has: + "event_type" -> "note" + "reference_index" -> int (1-based) or None + "response_index" -> int (1-based) or None + "operation_type" -> "match", "replacement", "missing", "extra" "pitch_correct" -> bool - "pitch_diff" -> int (semitones) or None if operation_type = missing/extra - "timing_correct" -> bool - "timing_abs_diff" -> float (seconds) or None if operation_type = missing/extra - “timing_relative_diff” -> float (seconds) or None if operation_type = missing/extra + "pitch_diff" -> int (semitones) or None + "timing_correct" -> bool + "timing_abs_diff" -> float (seconds) or None + "timing_relative_diff" -> float or None "duration_correct" -> bool - "duration_abs_diff"-> float (seconds) or None if operation_type = missing/extra - "duration_relative_diff" -> float (seconds) or None if operation_type = missing/extra + "duration_abs_diff" -> float (seconds) or None + "duration_relative_diff" -> float or None + + For chord events, each dict has: + "event_type" -> "chord" + "reference_index" -> int (1-based) or None + "response_index" -> int (1-based) or None + "operation_type" -> "match", "replacement", "missing", "extra" + "chord_name_ref" -> str e.g. "C major", or None + "chord_name_res" -> str e.g. "C minor", or None + "chord_accuracy" -> float (0 to 1) or None + "correct_pitches" -> list of pitch class ints or None + "missing_pitches" -> list of pitch class ints or None + "extra_pitches" -> list of pitch class ints or None + "timing_correct" -> bool + "timing_abs_diff" -> float (seconds) or None + "timing_relative_diff" -> float or None + "duration_correct" -> bool + "duration_abs_diff" -> float (seconds) or None + "duration_relative_diff" -> float or None """ # Compute IOI for each reference note: ioi[m] = ref_notes[m]["start"] - ref_notes[m-1]["start"] # floor at 0.05s to avoid division by zero issues - ref_ioi = [None] * len(ref_notes) - for m in range(1, len(ref_notes)): - interval = ref_notes[m]["start"] - ref_notes[m - 1]["start"] + ref_ioi = [None] * len(ref_events) + for m in range(1, len(ref_events)): + interval = ref_events[m]["event_start"] - ref_events[m - 1]["event_start"] ref_ioi[m] = max(interval, 0.05) - note_level_results = [] + event_level_results = [] for op in operations: res_idx = op["response_idx"] ref_idx = op["reference_idx"] op_type = op["type"] + # Determine event type from whichever side is available + if ref_idx is not None: + event_type = ref_events[ref_idx]["event_type"] + else: + event_type = response_events[res_idx]["event_type"] + # Missing/extra notes: no pitch/timing/duration comparison is possible, # so all the numeric fields are set to None. if op_type in ("missing", "extra"): - note_level_results.append({ - "reference_index": (ref_idx + 1) if ref_idx is not None else None, - "response_index": (res_idx + 1) if res_idx is not None else None, - "operation_type": op_type, - "pitch_correct": False, - "pitch_diff": None, - "timing_correct": False, - "timing_abs_diff": None, - "timing_relative_diff": None, - "duration_correct": False, - "duration_abs_diff": None, - "duration_relative_diff": None, - }) + if event_type == "note": + event_level_results.append({ + "event_type": "note", + "reference_index": (ref_idx + 1) if ref_idx is not None else None, + "response_index": (res_idx + 1) if res_idx is not None else None, + "operation_type": op_type, + "pitch_correct": False, + "pitch_diff": None, + "timing_correct": False, + "timing_abs_diff": None, + "timing_relative_diff": None, + "duration_correct": False, + "duration_abs_diff": None, + "duration_relative_diff": None, + }) + else: + if ref_idx is not None: + chord_name = identify_chord_name(ref_events[ref_idx]["notes"]) + else: + chord_name = identify_chord_name( + response_events[res_idx]["notes"] + ) + event_level_results.append({ + "event_type": "chord", + "reference_index": (ref_idx + 1) if ref_idx is not None else None, + "response_index": (res_idx + 1) if res_idx is not None else None, + "operation_type": op_type, + "chord_name_ref": chord_name if op_type == "missing" else None, + "chord_name_res": chord_name if op_type == "extra" else None, + "chord_accuracy": None, + "correct_pitches": None, + "missing_pitches": None, + "extra_pitches": None, + "timing_correct": False, + "timing_abs_diff": None, + "timing_relative_diff": None, + "duration_correct": False, + "duration_abs_diff": None, + "duration_relative_diff": None, + }) else: - # Matched (aligned) note pair - res_note = response_notes[res_idx] - ref_note = ref_notes[ref_idx] - - ## Pitch - pitch_diff = int(abs(res_note["pitch"] - ref_note["pitch"])) - pitch_correct = (pitch_diff == 0) + # Aligned event pair (match or replacement) + res_event = response_events[res_idx] + ref_event = ref_events[ref_idx] # Timing — residual after removing the global tempo trend - predicted_start = timing_scale * ref_note["start"] + timing_offset - timing_abs_diff = abs(res_note["start"] - predicted_start) + predicted_start = timing_scale * ref_event["event_start"] + timing_offset + timing_abs_diff = abs(res_event["event_start"] - predicted_start) if ref_idx == 0: # First note will start at 0, so no difference. timing_relative_diff = None @@ -356,46 +649,77 @@ def note_level_feedback(operations, response_notes, ref_notes, timing_correct = (timing_relative_diff <= timing_relative_threshold) # Duration — residual after removing the global duration-scale trend - predicted_duration = duration_scale * ref_note["duration"] - duration_abs_diff = res_note["duration"] - predicted_duration - ref_dur = max(ref_note["duration"], 0.05) # floor at 0.05s to avoid division by zero issues + predicted_duration = duration_scale * ref_event["event_duration"] + duration_abs_diff = abs(res_event["event_duration"] - predicted_duration) + ref_dur = max(ref_event["event_duration"], 0.05) # floor at 0.05s to avoid division by zero issues duration_relative_diff = duration_abs_diff / ref_dur - duration_correct = (abs(duration_relative_diff) <= duration_relative_threshold) - - note_level_results.append({ - "reference_index": ref_idx + 1, - "response_index": res_idx + 1, - "operation_type": op_type, - "pitch_correct": pitch_correct, - "pitch_diff": pitch_diff, - "timing_correct": timing_correct, - "timing_abs_diff": timing_abs_diff, - "timing_relative_diff": timing_relative_diff, - "duration_correct": duration_correct, - "duration_abs_diff": duration_abs_diff, - "duration_relative_diff": duration_relative_diff, - }) + duration_correct = (duration_relative_diff <= duration_relative_threshold) + if event_type == "note": + # For single-note events, pitch correctness is based on absolute pitch difference. + pitch1 = res_event["notes"][0]["pitch"] + pitch2 = ref_event["notes"][0]["pitch"] + pitch_diff = int(abs(pitch1 - pitch2)) + event_level_results.append({ + "event_type": "note", + "reference_index": ref_idx + 1, + "response_index": res_idx + 1, + "operation_type": op_type, + "pitch_correct": (pitch_diff == 0), + "pitch_diff": pitch_diff, + "timing_correct": timing_correct, + "timing_abs_diff": timing_abs_diff, + "timing_relative_diff": timing_relative_diff, + "duration_correct": duration_correct, + "duration_abs_diff": duration_abs_diff, + "duration_relative_diff": duration_relative_diff, + }) + else: + # For chord events, pitch correctness is based on the chord accuracy metric. + accuracy, correct_pcs, missing_pcs, extra_pcs = ( + compute_chord_accuracy( + ref_event["notes"], res_event["notes"] + ) + ) + event_level_results.append({ + "event_type": "chord", + "reference_index": ref_idx + 1, + "response_index": res_idx + 1, + "operation_type": op_type, + "chord_name_ref": identify_chord_name(ref_event["notes"]), + "chord_name_res": identify_chord_name(res_event["notes"]), + "chord_accuracy": accuracy, + "correct_pitches": correct_pcs, + "missing_pitches": missing_pcs, + "extra_pitches": extra_pcs, + "timing_correct": timing_correct, + "timing_abs_diff": timing_abs_diff, + "timing_relative_diff": timing_relative_diff, + "duration_correct": duration_correct, + "duration_abs_diff": duration_abs_diff, + "duration_relative_diff": duration_relative_diff, + }) - return note_level_results + return event_level_results # Step 4 -- compute_stats # ------------------------------------------------------------------------------ -def compute_stats(note_details, ref_notes, timing_scale=1.0, +def compute_stats(event_level_results, ref_events, timing_scale=1.0, timing_offset=0.0, duration_scale=1.0): """ - Compute summary counts and correctness booleans from note-level feedback. + Compute summary counts and correctness booleans from event-level feedback. Args: - note_details: list of dicts, output of note_level_feedback() - ref_notes: list of reference note dicts + event_level_results: list of dicts, output of event_level_feedback() + ref_events: list of reference event dicts timing_scale: float, from estimate_global_timing() timing_offset: float, from estimate_global_timing() duration_scale: float, from estimate_global_duration_scale() Returns: stats: dict with keys: - "pitch_all_correct" -> bool + "pitch_all_aligned_correct" -> bool, True if all note pitches are + correct AND all chord accuracies are 1.0 "timing_all_correct" -> bool "duration_all_correct" -> bool "total_notes_in_reference" -> int @@ -408,34 +732,79 @@ def compute_stats(note_details, ref_notes, timing_scale=1.0, "timing_scale" -> float "timing_offset" -> float "duration_scale" -> float + "total_chords_in_reference" -> int + "total_chords_missing" -> int + "total_chords_extra" -> int + "total_chords_correct" -> int (accuracy == 1.0) + "total_chords_imperfect" -> int (0.0 < accuracy < 1.0) + "total_chords_wrong" -> int (accuracy == 0.0) """ - paired = [n for n in note_details - if n["operation_type"] in ("match", "replacement")] + note_events = [n for n in event_level_results if n["event_type"] == "note"] + chord_events = [ch for ch in event_level_results if ch["event_type"] == "chord"] + + ref_note_count = sum(1 for n in ref_events if n["event_type"] == "note") + ref_chord_count = sum(1 for ch in ref_events if ch["event_type"] == "chord") + + paired_notes = [ + n for n in note_events + if n["operation_type"] in ("match", "replacement") + ] + paired_chords = [ + ch for ch in chord_events + if ch["operation_type"] in ("match", "replacement") + ] + all_paired = [ + e for e in event_level_results + if e["operation_type"] in ("match", "replacement") + ] stats = { - "pitch_all_correct": all(n["pitch_correct"] for n in paired), - "timing_all_correct": all(n["timing_correct"] for n in paired), - "duration_all_correct": all(n["duration_correct"] for n in paired), - "total_notes_in_reference": len(ref_notes), - "total_notes_missing": sum(1 for n in note_details if n["operation_type"] == "missing"), - "total_notes_extra": sum(1 for n in note_details if n["operation_type"] == "extra"), - "total_notes_wrong_pitch": sum(1 for n in paired if not n["pitch_correct"]), - "total_notes_wrong_timing": sum(1 for n in paired if not n["timing_correct"]), - "total_notes_wrong_duration": sum(1 for n in paired if not n["duration_correct"]), - "total_notes_correct": sum(1 for n in paired + "pitch_all_aligned_correct": ( + all(n["pitch_correct"] for n in paired_notes) + and all( + ch["chord_accuracy"] == 1.0 + for ch in paired_chords + if ch["chord_accuracy"] is not None + ) + ), + "timing_all_correct": all(n["timing_correct"] for n in all_paired), + "duration_all_correct": all(n["duration_correct"] for n in all_paired), + "total_notes_in_reference": ref_note_count, + "total_notes_missing": sum(1 for n in note_events if n["operation_type"] == "missing"), + "total_notes_extra": sum(1 for n in note_events if n["operation_type"] == "extra"), + "total_notes_wrong_pitch": sum(1 for n in paired_notes if not n["pitch_correct"]), + "total_notes_wrong_timing": sum(1 for n in paired_notes if not n["timing_correct"]), + "total_notes_wrong_duration": sum(1 for n in paired_notes if not n["duration_correct"]), + "total_notes_correct": sum(1 for n in paired_notes if n["pitch_correct"] and n["timing_correct"] and n["duration_correct"] ), "timing_scale": timing_scale, "timing_offset": timing_offset, "duration_scale": duration_scale, + "total_chords_in_reference": ref_chord_count, + "total_chords_missing": sum(1 for ch in chord_events if ch["operation_type"] == "missing"), + "total_chords_extra": sum(1 for ch in chord_events if ch["operation_type"] == "extra"), + "total_chords_correct": sum( + 1 for ch in paired_chords + if ch["chord_accuracy"] is not None and ch["chord_accuracy"] == 1.0 + ), + "total_chords_imperfect": sum( + 1 for ch in paired_chords + if ch["chord_accuracy"] is not None + and ch["chord_accuracy"] > 0.0 + and ch["chord_accuracy"] < 1.0 + ), + "total_chords_wrong": sum( + 1 for ch in paired_chords + if ch["chord_accuracy"] is not None and ch["chord_accuracy"] == 0.0 + ), } - return stats # Step 5 -- generate_feedback_message # ------------------------------------------------------------------------------ -def generate_feedback_message(note_details, response_notes, ref_notes, stats, +def generate_feedback_message(event_details, response_events, ref_events, stats, global_slow_threshold=GLOBAL_SLOW_THRESHOLD, global_fast_threshold=GLOBAL_FAST_THRESHOLD): """ @@ -443,12 +812,13 @@ def generate_feedback_message(note_details, response_notes, ref_notes, stats, Part 1 - Overview: summary of timing trend, duration trend, and total counts of each error type (pitch / missing / extra). - Part 2 - Detail: indicate exactly which notes have which problems. + Part 2 - Note Detail: pitch, timing, duration errors per note + Part 3 - Chord Detail: errors per chord Args: - note_details: list of dicts, output of note_level_feedback() - response_notes: list of student note dicts - ref_notes: list of reference note dicts + event_details: list of dicts, output of event_level_feedback() + response_events: list of event dicts from group_notes_into_events + ref_events: list of event dicts from group_notes_into_events stats: dict, output of compute_stats() global_slow_threshold: timing_scale above this triggers "too slow" message global_fast_threshold: timing_scale below this triggers "too fast" message @@ -456,37 +826,56 @@ def generate_feedback_message(note_details, response_notes, ref_notes, stats, Returns: feedback_message (str) """ - paired = [] - for n in note_details: - if n["operation_type"] in ("match", "replacement"): - paired.append(n) - - timing_scale = stats["timing_scale"] - timing_offset = stats["timing_offset"] + note_events = [n for n in event_details if n["event_type"] == "note"] + chord_events = [ch for ch in event_details if ch["event_type"] == "chord"] + + paired_notes = [ + n for n in note_events + if n["operation_type"] in ("match", "replacement") + ] + paired_chords = [ + ch for ch in chord_events + if ch["operation_type"] in ("match", "replacement") + ] + + timing_scale = stats["timing_scale"] + timing_offset = stats["timing_offset"] duration_scale = stats["duration_scale"] overview_messages = [] - detail_messages = [] + note_detail_messages = [] + chord_detail_messages = [] # ---------- Part 1: Overview ---------- - # Tempo: acceptable / too slow / too fast --- + # Tempo: acceptable / too slow / too fast timing_pct = abs(timing_scale - 1.0) * 100 duration_pct = abs(duration_scale - 1.0) * 100 - timing_direction = "behind" if timing_scale > 1.0 else "ahead of" - duration_direction = "longer" if duration_scale > 1.0 else "shorter" + if timing_scale > 1: + timing_direction = "behind" + elif timing_scale < 1: + timing_direction = "ahead of" + else: + timing_direction = "the same as" + + if duration_scale > 1: + duration_direction = "longer than" + elif duration_scale < 1: + duration_direction = "shorter than" + else: + duration_direction = "the same as" if timing_scale > global_slow_threshold: overview_messages.append( f"Overall, your tempo is slower than the reference " f"(timing is about {timing_pct:.0f}% {timing_direction} the reference in general while " - f"notes are held about {duration_pct:.0f}% {duration_direction} than the reference). " + f"notes are held about {duration_pct:.0f}% {duration_direction} the reference). " f"No worries! You will get better when you practice more to get more familiar with it!" ) elif timing_scale < global_fast_threshold: overview_messages.append( f"Overall, your tempo is faster than the reference " f"(timing is about {timing_pct:.0f}% {timing_direction} the reference in general while " - f"notes are held about {duration_pct:.0f}% {duration_direction} than the reference). " + f"notes are held about {duration_pct:.0f}% {duration_direction} the reference). " f"Don't rush even if you are confident in your performance." f"Slow down and give each note its full value." ) @@ -497,7 +886,7 @@ def generate_feedback_message(note_details, response_notes, ref_notes, stats, f"notes are held about {duration_pct:.0f}% {duration_direction} than the reference." ) - # Wrong pitch counts + # Wrong notes pitch counts if stats["total_notes_wrong_pitch"] > 0: s = "is" if stats["total_notes_wrong_pitch"] == 1 else "are" note_word = "note" if stats["total_notes_wrong_pitch"] == 1 else "notes" @@ -506,7 +895,7 @@ def generate_feedback_message(note_details, response_notes, ref_notes, stats, ) else: overview_messages.append("There are no pitch errors. Well done!") - # Missing counts + # Missing notes counts if stats["total_notes_missing"] > 0: s = "is" if stats["total_notes_missing"] == 1 else "are" note_word = "note" if stats["total_notes_missing"] == 1 else "notes" @@ -515,7 +904,7 @@ def generate_feedback_message(note_details, response_notes, ref_notes, stats, ) else: overview_messages.append("There are no missing notes. Great!") - # Extra counts + # Extra notes counts if stats["total_notes_extra"] > 0: s = "is" if stats["total_notes_extra"] == 1 else "are" note_word = "note" if stats["total_notes_extra"] == 1 else "notes" @@ -525,63 +914,127 @@ def generate_feedback_message(note_details, response_notes, ref_notes, stats, ) else: overview_messages.append("There are no extra notes. Good job!") + # Chord errors counts + if stats["total_chords_in_reference"] > 0: + total = stats["total_chords_in_reference"] + correct = stats["total_chords_correct"] + imperfect = stats["total_chords_imperfect"] + wrong = stats["total_chords_wrong"] + overview_messages.append( + f"Chords: {correct}/{total} correct, " + f"{imperfect}/{total} imperfect (some notes missing or extra), " + f"{wrong}/{total} completely wrong." + ) + if stats["total_chords_missing"] > 0: + c_word = "chord" if stats["total_chords_missing"] == 1 else "chords" + overview_messages.append( + f"{stats['total_chords_missing']} {c_word} missed." + ) + if stats["total_chords_extra"] > 0: + c_word = "chord" if stats["total_chords_extra"] == 1 else "chords" + overview_messages.append( + f"{stats['total_chords_extra']} extra {c_word} played." + ) - # ---------- Part 2: Detail ---------- + # ---------- Part 2: Note Detail ---------- # Missing / extra notes - for n in note_details: + for n in note_events: if n["operation_type"] == "missing": ref_zero_based = n["reference_index"] - 1 - pitch = ref_notes[ref_zero_based]["pitch"] - detail_messages.append( + pitch = ref_events[ref_zero_based]["notes"][0]["pitch"] + note_detail_messages.append( f"Note {n['reference_index']} (pitch {pitch}) is missing in your performance." ) elif n["operation_type"] == "extra": res_zero_based = n["response_index"] - 1 - extra = response_notes[res_zero_based] - detail_messages.append( - f"Extra note played: pitch {extra['pitch']} at t={extra['start']:.2f}s ") - + extra = response_events[res_zero_based]["notes"][0]["pitch"] + note_detail_messages.append( + f"Extra note played: pitch {extra} " + f"at t={response_events[res_zero_based]['event_start']:.2f}s ") # Pitch errors - for n in paired: + for n in paired_notes: if not n["pitch_correct"]: ref_zero_based = n["reference_index"] - 1 res_zero_based = n["response_index"] - 1 - ref_p = ref_notes[ref_zero_based]["pitch"] - res_p = response_notes[res_zero_based]["pitch"] - detail_messages.append( + ref_p = ref_events[ref_zero_based]["notes"][0]["pitch"] + res_p = response_events[res_zero_based]["notes"][0]["pitch"] + note_detail_messages.append( f"Note {n['reference_index']}: wrong pitch — " f"expected {ref_p}, played {res_p} " f"({n['pitch_diff']} semitone(s) off)." ) - # Local timing errors - these are residuals after removing the global timing trend - for n in paired: + for n in paired_notes: if not n["timing_correct"]: - detail_messages.append( - f"Note {n['reference_index']}: timing is off by {n['timing_abs_diff']:.2f}s " + note_detail_messages.append( + f"Note {n['reference_index']}: timing is off by " + f"{n['timing_abs_diff']:.2f}s " f"({n['timing_relative_diff'] * 100:.0f}% of the expected note interval), " f"after accounting for the overall tempo trend." ) - # Local duration errors — these are residuals after removing the global duration trend - for n in paired: + for n in paired_notes: if not n["duration_correct"]: direction = "longer" if n["duration_abs_diff"] > 0 else "shorter" - ref_zero_based = n["reference_index"] - 1 - ref_dur = ref_notes[ref_zero_based]["duration"] - duration_pct = abs(n["duration_relative_diff"]) * 100 - detail_messages.append( - f"Note {n['reference_index']}: duration is {abs(n['duration_abs_diff']):.2f}s " - f"{direction} than the reference (i.e. " - f"{duration_pct:.0f}% off) after accounting for the overall duration trend " + duration_pct_err = abs(n["duration_relative_diff"]) * 100 + note_detail_messages.append( + f"Note {n['reference_index']}: duration is " + f"{abs(n['duration_abs_diff']):.2f}s {direction} than the reference " + f"({duration_pct_err:.0f}% off) after accounting for the overall duration trend." + ) + + # ---------- Part 3: Chord Detail ---------- + # Missing / extra chords + for ch in chord_events: + if ch["operation_type"] == "missing": + chord_detail_messages.append( + f"Chord {ch['reference_index']} ({ch['chord_name_ref']}) " + f"is missing in your performance." + ) + elif ch["operation_type"] == "extra": + chord_detail_messages.append( + f"Extra chord played: {ch['chord_name_res']} " + f"at event position {ch['response_index']}." + ) + # Chord accuracy errors + for ch in paired_chords: + if ch["chord_accuracy"] is not None and ch["chord_accuracy"] < 1.0: + accuracy_pct = round(ch["chord_accuracy"] * 100) + message = ( + f"Chord {ch['reference_index']} " + f"(expected {ch['chord_name_ref']}, you played {ch['chord_name_res']}): " + f"{accuracy_pct}% accurate. " + ) + if ch["missing_pitches"]: + missing_names = [PITCH_CLASS_NAMES[pc] for pc in ch["missing_pitches"]] + message = message + "Missing note(s): " + ", ".join(missing_names) + ". " + if ch["extra_pitches"]: + extra_names = [PITCH_CLASS_NAMES[pc] for pc in ch["extra_pitches"]] + message = message + "Extra note(s) played: " + ", ".join(extra_names) + "." + chord_detail_messages.append(message) + # Local timing errors for chords + for ch in paired_chords: + if not ch["timing_correct"] and ch["timing_relative_diff"] is not None: + chord_detail_messages.append( + f"Chord {ch['reference_index']}: timing is off by " + f"{ch['timing_abs_diff']:.2f}s " + f"({ch['timing_relative_diff'] * 100:.0f}% of the expected interval)." ) all_messages = ["Overview: "] + overview_messages - if detail_messages: - all_messages = all_messages + ["", "Detail: "] + detail_messages + if note_detail_messages: + all_messages = all_messages + ["", "Note Detail:"] + note_detail_messages else: - all_messages = all_messages + ["", "Great performance! No further issues found."] + all_messages = all_messages + ["", "All melody notes played correctly!!"] + + if stats["total_chords_in_reference"] > 0: + if chord_detail_messages: + all_messages = ( + all_messages + ["", "Chord Detail:"] + chord_detail_messages + ) + else: + all_messages = all_messages + ["", "Great performance! No further issues on chords found."] return "\n".join(all_messages) @@ -600,31 +1053,32 @@ class FeedbackResult: Attributes ---------- is_correct : bool - True only if every note is perfectly matched on pitch, timing, and duration. + True only if every note/chord has correct pitch, + timing and duration, with no missing or extra events. stats : dict - Aggregate counts — see compute_stats() for the full key list. - note_details : list of dicts - Per-note analysis, one dict per alignment operation. - Each dict has the keys described in note_level_feedback(). + Aggregate counts. See compute_stats() for all keys. + event_details : list of dicts + Per-event analysis, one dict per alignment operation. + Each dict has an "event_type" key ("note" or "chord"), + plus type-specific fields. See event_level_feedback() for details. feedback_message : str Human-readable feedback string, ready to display to the student. - see generate_feedback_message() for details. operations : list of dicts - Raw alignment operations from note_alignment_ED(). + Raw alignment operations from event_alignment_ED(). Kept here so visualisation helpers (plot_cost_matrix etc.) can use them. D : numpy.ndarray Accumulated cost matrix from the alignment step. """ - def __init__(self, is_correct, stats, note_details, + def __init__(self, is_correct, stats, event_details, feedback_message, operations, D): self.is_correct = is_correct self.stats = stats - self.note_details = note_details + self.event_details = event_details self.feedback_message = feedback_message self.operations = operations self.D = D - + def __repr__(self): return ( "FeedbackResult(is_correct=" + str(self.is_correct) + ", " @@ -639,41 +1093,47 @@ def compare_performance_ED(responseMIDI, refMIDI, timing_relative_threshold=TIMING_RELATIVE_THRESHOLD, duration_relative_threshold=DURATION_RELATIVE_THRESHOLD, global_slow_threshold=GLOBAL_SLOW_THRESHOLD, - global_fast_threshold=GLOBAL_FAST_THRESHOLD): + global_fast_threshold=GLOBAL_FAST_THRESHOLD, + chord_onset_window=DEFAULT_CHORD_ONSET_WINDOW): """ - Full pipeline: normalisation -> alignment -> estimate global trends - -> note-level evaluation -> summary statistics -> feedback. - + Full pipeline: normalisation -> grouping -> alignment -> global trends + -> event-level evaluation -> summary statistics -> feedback. + Args: responseMIDI: student MIDI dict with key "notes" refMIDI: reference MIDI dict with key "notes" - gap_penalty: cost of an unaligned note - timing_relative_threshold: see note_level_feedback() - duration_relative_threshold: see note_level_feedback() + gap_penalty: cost of an unaligned event + timing_relative_threshold: see event_level_feedback() + duration_relative_threshold: see event_level_feedback() global_slow_threshold: see generate_feedback_message() global_fast_threshold: see generate_feedback_message() - + chord_onset_window: float (seconds), notes within this window are + grouped into a chord. Default 0.050 (50ms). Teacher-configurable. + Returns: FeedbackResult object containing all analysis results """ # Step 0: Normalise start times response_notes = normalize_start_times(responseMIDI["notes"]) - ref_notes = normalize_start_times(refMIDI["notes"]) + ref_notes = normalize_start_times(refMIDI["notes"]) + # Group notes into events (single notes or chords) + response_events = group_notes_into_events(response_notes, chord_onset_window) + ref_events = group_notes_into_events(ref_notes, chord_onset_window) - # Step 1: Align notes using edit distance - operations, D = note_alignment_ED(response_notes, ref_notes, gap_penalty) + # Step 1: Align events using edit distance + operations, D = event_alignment_ED(response_events, ref_events, gap_penalty) - # Step 2: Estimate the overall tempo trend + # Step 2: Estimate the overall tempo trend (global timing and duration trends) timing_scale, timing_offset = estimate_global_timing( - operations, response_notes, ref_notes + operations, response_events, ref_events ) duration_scale = estimate_global_duration_scale( - operations, response_notes, ref_notes + operations, response_events, ref_events ) - # Step 3: Note-level evaluation - note_details = note_level_feedback( - operations, response_notes, ref_notes, + # Step 3: Event details feedback + event_details = event_level_feedback( + operations, response_events, ref_events, timing_scale=timing_scale, timing_offset=timing_offset, duration_scale=duration_scale, @@ -683,15 +1143,15 @@ def compare_performance_ED(responseMIDI, refMIDI, # Step 4: Compute summary statistics stats = compute_stats( - note_details, ref_notes, + event_details, ref_events, timing_scale=timing_scale, timing_offset=timing_offset, duration_scale=duration_scale, ) - # Step 5: Generate the human-readable feedback text + # Step 5: Generate human-readable feedback feedback_message = generate_feedback_message( - note_details, response_notes, ref_notes, stats, + event_details, response_events, ref_events, stats, global_slow_threshold=global_slow_threshold, global_fast_threshold=global_fast_threshold, ) @@ -700,7 +1160,9 @@ def compare_performance_ED(responseMIDI, refMIDI, is_correct = ( stats["total_notes_missing"] == 0 and stats["total_notes_extra"] == 0 - and stats["pitch_all_correct"] + and stats["total_chords_missing"] == 0 + and stats["total_chords_extra"] == 0 + and stats["pitch_all_aligned_correct"] and stats["timing_all_correct"] and stats["duration_all_correct"] ) @@ -708,7 +1170,7 @@ def compare_performance_ED(responseMIDI, refMIDI, return FeedbackResult( is_correct=is_correct, stats=stats, - note_details=note_details, + event_details=event_details, feedback_message=feedback_message, operations=operations, D=D, diff --git a/notebooks/Chord_comparison.ipynb b/notebooks/Chord_comparison.ipynb index 1c71ecf..32db8fe 100644 --- a/notebooks/Chord_comparison.ipynb +++ b/notebooks/Chord_comparison.ipynb @@ -23,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 35, "id": "db6e1544", "metadata": {}, "outputs": [], @@ -45,1199 +45,30 @@ " res_chords_data = json.load(f2)" ] }, - { - "cell_type": "markdown", - "id": "ae192f73", - "metadata": {}, - "source": [ - "chord accuracy metric: https://arxiv.org/pdf/2201.05244\n", - "\n", - "consider chord as unordered set of pitch classes\n", - "https://www.researchgate.net/profile/Pierre-Hanna/publication/265984596_A_Survey_Of_Chord_Distances_With_Comparison_For_Chord_Analysis/links/555c4c7808aec5ac2232b158/A-Survey-Of-Chord-Distances-With-Comparison-For-Chord-Analysis.pdf" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "c04a662c", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "\n", - "# Default thresholds / parameters\n", - "# Teachers can override any of these via the params dict in evaluation_function.\n", - "# ------------------------------------------------------------------------------\n", - "# Gap penalty: cost of leaving a note unaligned (insertion/deletion)\n", - "DEFAULT_GAP_PENALTY = 6\n", - "\n", - "# Timing: |response_start - predicted_start| / IOI must be below this.\n", - "# e.g. 0.20 means the start can be off by up to 20% of the inter-onset interval.\n", - "TIMING_RELATIVE_THRESHOLD = 0.20\n", - "\n", - "# Duration: |response_dur / ref_dur - 1| must be below this.\n", - "# e.g. 0.25 means the student's duration can be off by up to 25% of the reference.\n", - "DURATION_RELATIVE_THRESHOLD = 0.25\n", - "\n", - "# Thresholds that trigger a global tempo comment in the overview.\n", - "GLOBAL_SLOW_THRESHOLD = 1.15 # timing_scale > 1.15 -> \"overall too slow\"\n", - "GLOBAL_FAST_THRESHOLD = 0.85 # timing_scale < 0.85 -> \"overall too fast\"\n", - "\n", - "# Step 0 - make first note start at t = 0.0\n", - "# ------------------------------------------------------------------------------\n", - "def normalize_start_times(notes):\n", - " \"\"\"\n", - " Shift all notes so that the first note starts at t=0.\n", - " \n", - " Args:\n", - " notes: list of note dicts, each with at least a \"start\" key.\n", - " \n", - " Returns:\n", - " A new list of note dicts (copies, not the original objects), with\n", - " every \"start\" value shifted so notes[0][\"start\"] == 0. Returns an\n", - " empty list unchanged if notes is empty.\n", - " \"\"\"\n", - " if not notes:\n", - " return []\n", - " \n", - " first_start = notes[0][\"start\"]\n", - " \n", - " shifted_notes = []\n", - " for note in notes:\n", - " # Create a copy of the note dict with the \"start\" time shifted\n", - " note_copy = {\n", - " \"pitch\": note[\"pitch\"],\n", - " \"start\": note[\"start\"] - first_start,\n", - " \"duration\": note[\"duration\"],\n", - " }\n", - " shifted_notes.append(note_copy)\n", - " \n", - " return shifted_notes" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fc06c1bb", - "metadata": {}, - "outputs": [], - "source": [ - "# Default threshold: notes starting within 50ms are grouped as one chord.\n", - "DEFAULT_CHORD_ONSET_WINDOW = 0.05\n", - "\n", - "# Chord template dictionary.\n", - "# Each entry maps a chord quality name to a set of pitch class intervals,\n", - "# where the root note is normalised to 0.\n", - "# Only the 4 triad types are included, following Muller (2021) Chapter 5.\n", - "CHORD_TEMPLATES = {\n", - " \"major\": set([0, 4, 7]),\n", - " \"minor\": set([0, 3, 7]),\n", - " \"diminished\": set([0, 3, 6]),\n", - " \"augmented\": set([0, 4, 8]),\n", - "}\n", - " \n", - "# Pitch class names for human-readable feedback messages.\n", - "PITCH_CLASS_NAMES = [\"C\", \"C#\", \"D\", \"D#\", \"E\", \"F\",\n", - " \"F#\", \"G\", \"G#\", \"A\", \"A#\", \"B\"]\n", - "\n", - "# helper function to build an event dict from a group of notes\n", - "def make_event(notes_in_group):\n", - " \"\"\"\n", - " Build a single event dict from a group of notes.\n", - " \n", - " Args:\n", - " notes_in_group: list of one or more note dicts.\n", - " \n", - " Returns:\n", - " event dict with keys \"event_type\", \"notes\", \"start\". example:\n", - " {\n", - " \"event_type\": \"note\" or \"chord\" depending on the number of notes in the group\n", - " \"notes\": [ \n", - " { \n", - " \"pitch\": int\n", - " \"start\": float\n", - " \"duration\": float\n", - " },\n", - " ]\n", - " \"event_start\": float,\n", - " \"event_duration\": float\n", - " }\n", - " \"\"\"\n", - " event_type = \"note\" if len(notes_in_group) == 1 else \"chord\"\n", - " return {\n", - " \"event_type\": event_type,\n", - " \"notes\": notes_in_group,\n", - " # use the start time of the first note in the group as the event start\n", - " \"event_start\": notes_in_group[0][\"start\"], \n", - " # use the duration of the first note in the group as the event duration\n", - " \"event_duration\": notes_in_group[0][\"duration\"], \n", - " }\n", - "\n", - "\n", - "# group notes into events based on their start times\n", - "def group_notes_into_events(notes, chord_onset_window=DEFAULT_CHORD_ONSET_WINDOW):\n", - " \"\"\"\n", - " Group a flat list of notes into events. Notes whose start times fall\n", - " within chord_onset_window seconds of each other are placed into the\n", - " same event (i.e. treated as a chord). Notes that are not grouped with\n", - " any other note form a single-note event.\n", - "\n", - " Args:\n", - " notes: list of dicts, each with keys 'pitch', 'start', 'duration'\n", - " chord_onset_window: float, max time difference (seconds) to be grouped\n", - "\n", - " Returns:\n", - " list of dicts, each with keys:\n", - " 'pitches' : set of int MIDI pitch numbers\n", - " 'start' : float, earliest start time in the group\n", - " 'duration' : float, average duration of notes in the group\n", - " \"\"\"\n", - " if len(notes) == 0:\n", - " return []\n", - "\n", - " # Sort notes by start time first\n", - " sorted_notes = sorted(notes, key=lambda n: n[\"start\"])\n", - "\n", - " events = []\n", - " current_group = [sorted_notes[0]]\n", - " group_start = sorted_notes[0][\"start\"]\n", - "\n", - " for note in sorted_notes[1:]:\n", - " # Close enough in time: add to current group (chord)\n", - " if note[\"start\"] - group_start <= chord_onset_window:\n", - " current_group.append(note)\n", - " else:\n", - " # Too far apart: save current chord, start a new group\n", - " event = make_event(current_group)\n", - " events.append(event)\n", - " current_group = [note]\n", - " group_start = note[\"start\"]\n", - "\n", - " # append the last group\n", - " last_event = make_event(current_group)\n", - " events.append(last_event)\n", - "\n", - " return events" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "6dbac7c6", - "metadata": {}, - "outputs": [], - "source": [ - "def compute_note_cost(note1, note2):\n", - " \"\"\"\n", - " Cost of aligning (replacing) one note with another, based on pitch.\n", - " \n", - " cost = 0: pitches are identical (a 'match'). \n", - " cost > 0: different pitches (a 'replacement')\n", - " \n", - " Args:\n", - " note1: dict with keys \"pitch\" (int), \"start\" (float), \"duration\" (float)\n", - " note2: dict with keys \"pitch\" (int), \"start\" (float), \"duration\" (float)\n", - " \n", - " Returns:\n", - " int: cost value >= 0 (lower means more similar pitch)\n", - " \"\"\"\n", - " return int(abs(note1[\"pitch\"] - note2[\"pitch\"]))\n", - "\n", - "\n", - "def compute_event_cost(event1, event2, gap_penalty=DEFAULT_GAP_PENALTY):\n", - " \"\"\"\n", - " Cost of aligning (substituting) one event with another.\n", - " \n", - " Rules:\n", - " - note vs note: absolute pitch difference \n", - " - chord vs chord: Hamming distance on 12-dim pitch class binary vectors\n", - " - note vs chord (type mismatch): return gap_penalty so the aligner\n", - " treats them as unaligned (insertion + deletion is preferred)\n", - " \n", - " For chord vs chord, the Hamming distance counts how many of the 12\n", - " pitch classes differ between the two chords (symmetric difference size).\n", - " \n", - " Args:\n", - " event1: event dict (from group_notes_into_events)\n", - " event2: event dict (from group_notes_into_events)\n", - " gap_penalty: cost of an unaligned event; used as the type-mismatch cost\n", - " \n", - " Returns:\n", - " int: alignment cost >= 0\n", - " \"\"\"\n", - " type1 = event1[\"event_type\"]\n", - " type2 = event2[\"event_type\"]\n", - " \n", - " # Type mismatch: note vs chord or chord vs note.\n", - " # Return gap_penalty so alignment prefers to leave them unmatched.\n", - " if type1 != type2:\n", - " return gap_penalty\n", - " \n", - " # Both are single notes: use pitch difference (same as Phase 1)\n", - " elif type1 == \"note\" and type2 == \"note\":\n", - " return compute_note_cost(event1[\"notes\"][0], event2[\"notes\"][0])\n", - " \n", - " # Both are chords: Hamming distance on 12-dimensional pitch class vectors.\n", - " else:\n", - " vec1 = [0] * 12\n", - " vec2 = [0] * 12\n", - " for note in event1[\"notes\"]:\n", - " vec1[note[\"pitch\"] % 12] = 1\n", - " for note in event2[\"notes\"]:\n", - " vec2[note[\"pitch\"] % 12] = 1\n", - " \n", - " hamming = sum(1 for i in range(12) if vec1[i] != vec2[i])\n", - " return hamming" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "1f756e2f", - "metadata": {}, - "outputs": [], - "source": [ - "def event_alignment_ED(response_events, ref_events, gap_penalty=DEFAULT_GAP_PENALTY):\n", - " \"\"\"\n", - " Align events (notes or chords) using edit distance (ED). \n", - " The ED allows for insertions and deletions, which can be useful for \n", - " evaluating musical practice containing missing/extra notes.\n", - " \n", - " Args:\n", - " response_events: list of event dicts from group_notes_into_events\n", - " ref_events: list of event dicts from group_notes_into_events\n", - " gap_penalty: cost of leaving an event unaligned (insertion/deletion)\n", - " \n", - " Returns:\n", - " operations: list of transformation ops dicts, in order from first event to last:\n", - " {'type': 'match' or 'replacement' or 'missing' or 'extra', \n", - " 'response_idx': int or None, \n", - " 'reference_idx': int or None, \n", - " 'cost': int}\n", - " D: accumulated cost matrix, shape (N+1, M+1)\n", - " \"\"\"\n", - " # the rows of D correspond to response events\n", - " N = len(response_events)\n", - " # the columns of D correspond to reference events\n", - " M = len(ref_events)\n", - "\n", - " # Build the accumulated cost matrix D of size (N+1 x M+1)\n", - " D = np.zeros((N + 1, M + 1), dtype=int)\n", - " \n", - " # Boundary conditions: aligning against an empty sequence means every event\n", - " # is unaligned, so the cost is n (or m) times the gap penalty.\n", - " for n in range(1, N + 1):\n", - " D[n, 0] = n * gap_penalty # n extra response events\n", - " for m in range(1, M + 1):\n", - " D[0, m] = m * gap_penalty # m missing ref events\n", - "\n", - " # Recursion (accumulated cost / score matrix D):\n", - " for n in range(1, N + 1):\n", - " for m in range(1, M + 1):\n", - " replace_cost = compute_event_cost(response_events[n-1], ref_events[m-1], gap_penalty)\n", - " D[n, m] = min(\n", - " D[n-1, m-1] + replace_cost, # diagonal: match or replacement\n", - " D[n-1, m] + gap_penalty, # vertical: extra event response[n-1]\n", - " D[n, m-1] + gap_penalty, # horizontal: missing response for ref[m-1]\n", - " )\n", - "\n", - " # Backtrack and classify each transformation op based on movement direction in D\n", - " operations = []\n", - " n, m = N, M\n", - " while n > 0 or m > 0:\n", - " # Boundary conditions: at the top row, only horizontal moves possible\n", - " if n == 0:\n", - " # Missing response for ref[m-1] (deletion)\n", - " operations.append({\n", - " \"type\": \"missing\",\n", - " \"response_idx\": None,\n", - " \"reference_idx\": m - 1,\n", - " \"cost\": gap_penalty,\n", - " })\n", - " m -= 1\n", - " # At the leftmost column, only vertical moves possible\n", - " elif m == 0:\n", - " # Extra event response[n-1] (insertion)\n", - " operations.append({\n", - " \"type\": \"extra\",\n", - " \"response_idx\": n - 1,\n", - " \"reference_idx\": None,\n", - " \"cost\": gap_penalty,\n", - " })\n", - " n -= 1\n", - " # For all other cases, we can move in any direction (diagonal, vertical, horizontal)\n", - " else:\n", - " replace_cost = compute_event_cost(response_events[n - 1], ref_events[m - 1], gap_penalty)\n", - " diag = D[n - 1, m - 1] + replace_cost # diagonal: match or replacement\n", - " up = D[n - 1, m] + gap_penalty # vertical: extra event response[n-1]\n", - " left = D[n, m - 1] + gap_penalty # horizontal: missing response for ref[m-1]\n", - " min_cost = min(diag, up, left) # find the minimum cost step\n", - "\n", - " # classify the transformation ops based on the minimum cost step\n", - " if min_cost == diag: # Diagonal -> two events are aligned (match or replacement)\n", - " operations.append({\n", - " \"type\": \"match\" if replace_cost == 0 else \"replacement\",\n", - " \"response_idx\": n - 1,\n", - " \"reference_idx\": m - 1,\n", - " \"cost\": replace_cost,\n", - " })\n", - " n, m = n - 1, m - 1\n", - " elif min_cost == up: # Vertical -> response[n-1] is extra (insertion)\n", - " operations.append({\n", - " \"type\": \"extra\",\n", - " \"response_idx\": n - 1,\n", - " \"reference_idx\": None,\n", - " \"cost\": gap_penalty,\n", - " })\n", - " n -= 1\n", - " else: # Horizontal -> response is missing for ref[m-1] (deletion)\n", - " operations.append({\n", - " \"type\": \"missing\",\n", - " \"response_idx\": None,\n", - " \"reference_idx\": m - 1,\n", - " \"cost\": gap_penalty,\n", - " })\n", - " m -= 1\n", - "\n", - " operations.reverse() # Reverse to get ops in order from first note to last\n", - " return operations, D\n", - "\n", - "\n", - "def estimate_global_timing(operations, response_events, ref_events):\n", - " \"\"\"\n", - " Estimate the student's overall tempo relative to the reference, by fitting\n", - " a straight line through the matched note start times:\n", - " response_start ≈ scale * ref_start + offset\n", - " where:\n", - " scale > 1 means the student is playing slower overall\n", - " scale < 1 means the student is playing faster overall\n", - " offset captures any constant time shift\n", - "\n", - " Args:\n", - " operations: list of operation dicts (match/replacement/missing/extra)\n", - " response_events: list of event dicts from response\n", - " ref_events: list of event dicts from reference\n", - "\n", - " Returns:\n", - " scale: float, estimated tempo ratio (1.0 = same speed as reference)\n", - " offset: float (seconds), estimated constant time shift\n", - " \"\"\"\n", - " # Collect (ref_start, response_start) pairs from matched/replaced notes only.\n", - " # Missing/extra notes have no pair, so they cannot contribute to the fit.\n", - " ref_starts = []\n", - " response_starts = []\n", - " for op in operations:\n", - " if op[\"type\"] in (\"match\", \"replacement\"):\n", - " res = op[\"response_idx\"]\n", - " ref = op[\"reference_idx\"]\n", - " ref_starts.append(ref_events[ref][\"event_start\"])\n", - " response_starts.append(response_events[res][\"event_start\"])\n", - " \n", - " # Not enough points for fitting a meaningful line — assume no drift in tempo.\n", - " if len(ref_starts) < 3:\n", - " return 1.0, 0.0\n", - " \n", - " x = np.array(ref_starts, dtype=float)\n", - " y = np.array(response_starts, dtype=float)\n", - " \n", - " # Least-squares line fit: y = scale * x + offset\n", - " scale, offset = np.polyfit(x, y, 1)\n", - " \n", - " return float(scale), float(offset)\n", - "\n", - "\n", - "def estimate_global_duration_scale(operations, response_events, ref_events):\n", - " \"\"\"\n", - " Estimate the student's overall note-length scale relative to the reference,\n", - " by fitting a line through the origin:\n", - " response_duration ≈ duration_scale * ref_duration\n", - " where:\n", - " duration_scale > 1 means notes are held longer overall\n", - " duration_scale < 1 means notes are held shorter overall\n", - "\n", - " Args:\n", - " operations: output of note_alignment_ED()\n", - " response_events: list of student event dicts\n", - " ref_events: list of reference event dicts\n", - "\n", - " Returns:\n", - " duration_scale (float): estimated duration ratio (1.0 = same as reference)\n", - " \"\"\"\n", - " ref_durations = []\n", - " response_durations = []\n", - " for op in operations:\n", - " if op[\"type\"] in (\"match\", \"replacement\"):\n", - " res = op[\"response_idx\"]\n", - " ref = op[\"reference_idx\"]\n", - " ref_durations.append(ref_events[ref][\"event_duration\"])\n", - " response_durations.append(response_events[res][\"event_duration\"])\n", - "\n", - " if len(ref_durations) < 3:\n", - " return 1.0\n", - "\n", - " x = np.array(ref_durations, dtype=float)\n", - " y = np.array(response_durations, dtype=float)\n", - "\n", - " # Least-squares fit through the origin: y = scale * x\n", - " # Closed-form solution: scale = sum(x*y) / sum(x*x)\n", - " duration_scale = float(np.sum(x * y) / np.sum(x * x))\n", - "\n", - " return duration_scale" - ] - }, { "cell_type": "code", - "execution_count": 29, - "id": "bdb915ea", + "execution_count": 36, + "id": "a1800307", "metadata": {}, "outputs": [], "source": [ - "# Chord helper functions\n", - "def get_pitch_class_set(notes):\n", - " \"\"\"\n", - " Convert a list of notes to a set of pitch classes (each pitch mod 12).\n", - "\n", - " Args:\n", - " notes: list of note dicts, each with a \"pitch\" key.\n", - "\n", - " Returns:\n", - " set of ints, each in range [0, 11].\n", - " \"\"\"\n", - " return set(note[\"pitch\"] % 12 for note in notes)\n", - " \n", - " \n", - "def identify_chord_name(notes):\n", - " \"\"\"\n", - " Identify the chord name (e.g. \"C major\", \"A minor\") from a list of notes\n", - " by matching their pitch class set against CHORD_TEMPLATES.\n", - " \n", - " For each candidate root in the pitch class set, normalise all pitch classes\n", - " to start at 0 and check against each template. If no match is found,\n", - " returns \"unknown chord\".\n", - " \n", - " Args:\n", - " notes: list of note dicts, each with a \"pitch\" key.\n", - " \n", - " Returns:\n", - " str: chord name e.g. \"C major\", or \"unknown chord\".\n", - " \"\"\"\n", - " pitch_classes = get_pitch_class_set(notes)\n", - " \n", - " for root_pc in pitch_classes:\n", - " normalised = set((pc - root_pc) % 12 for pc in pitch_classes)\n", - " for chord_type, template in CHORD_TEMPLATES.items():\n", - " if normalised == template:\n", - " root_name = PITCH_CLASS_NAMES[root_pc]\n", - " return root_name + \" \" + chord_type\n", - " \n", - " return \"unknown chord\"\n", - " \n", - " \n", - "def compute_chord_accuracy(ref_notes, res_notes):\n", - " \"\"\"\n", - " Compute the chord accuracy score A from McLeod & Rohit (2022),\n", - " \n", - " A = (C - I + |y|) / (2 * |y|)\n", - " \n", - " where:\n", - " C = |y ∩ y_hat| (correctly played pitch classes)\n", - " I = |y_hat - y| (extra pitch classes played)\n", - " |y| (number of pitch classes in the reference chord)\n", - " \n", - " A = 1.0 means perfectly correct. A = 0.0 means nothing correct and\n", - " many extra notes are played.\n", - " \n", - " Args:\n", - " ref_notes: list of note dicts for the reference chord.\n", - " res_notes: list of note dicts for the response chord.\n", - " \n", - " Returns:\n", - " accuracy: float in [0, 1]\n", - " correct_pitches: sorted list of pitch class ints in both chords\n", - " missing_pitches: sorted list of pitch class ints in ref\n", - " extra_pitches:sorted list of pitch class ints in response\n", - " \"\"\"\n", - " ref_pcs = get_pitch_class_set(ref_notes)\n", - " res_pcs = get_pitch_class_set(res_notes)\n", - " \n", - " correct_pcs = ref_pcs & res_pcs\n", - " missing_pcs = ref_pcs - res_pcs\n", - " extra_pcs = res_pcs - ref_pcs\n", - " \n", - " C = len(correct_pcs)\n", - " I = len(extra_pcs)\n", - " ref_size = len(ref_pcs)\n", - " \n", - " if ref_size == 0:\n", - " accuracy = 0.0\n", - " else:\n", - " accuracy = (C - I + ref_size) / (2.0 * ref_size)\n", - " accuracy = max(0.0, min(1.0, accuracy))\n", - " \n", - " return accuracy, sorted(correct_pcs), sorted(missing_pcs), sorted(extra_pcs)" + "from evaluation_function.compare_MIDI import compare_performance_ED" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "88c61aa5", + "cell_type": "markdown", + "id": "ae192f73", "metadata": {}, - "outputs": [], "source": [ - "def event_level_feedback(operations, response_events, ref_events,\n", - " timing_scale=1.0, timing_offset=0.0,\n", - " duration_scale=1.0,\n", - " timing_relative_threshold=TIMING_RELATIVE_THRESHOLD,\n", - " duration_relative_threshold=DURATION_RELATIVE_THRESHOLD):\n", - " \"\"\"\n", - " Analyse each aligned event pair (or missing/extra event) and return a\n", - " list of event result dicts.\n", - " For single-note events, pitch evaluation uses absolute pitch difference.\n", - " For chord events, pitch evaluation uses the chord accuracy metric A:\n", - " A = (C - I + |y|) / (2 * |y|)\n", - " where C = correctly played pitch classes, I = extra pitch classes played,\n", - " |y| = number of pitch classes in the reference chord.\n", - "\n", - " Args:\n", - " operations: list of op dicts (match/replacement/missing/extra)\n", - " response_events: list of event dicts from response\n", - " ref_events: list of event dicts from reference\n", - " timing_scale: float, estimated tempo ratio (1.0 = same speed as reference)\n", - " timing_offset: float (seconds), estimated constant time shift\n", - " duration_scale: float, estimated overall duration ratio\n", - " timing_relative_threshold: float, relative tolerance for timing correctness\n", - " duration_relative_threshold: float, relative tolerance for duration correctness\n", - "\n", - " Returns:\n", - " event_level_results : list of dicts, each dict contains:\n", - "\n", - " For note events, each dict has:\n", - " \"event_type\" -> \"note\"\n", - " \"reference_index\" -> int (1-based) or None\n", - " \"response_index\" -> int (1-based) or None\n", - " \"operation_type\" -> \"match\", \"replacement\", \"missing\", \"extra\"\n", - " \"pitch_correct\" -> bool\n", - " \"pitch_diff\" -> int (semitones) or None\n", - " \"timing_correct\" -> bool\n", - " \"timing_abs_diff\" -> float (seconds) or None\n", - " \"timing_relative_diff\" -> float or None\n", - " \"duration_correct\" -> bool\n", - " \"duration_abs_diff\" -> float (seconds) or None\n", - " \"duration_relative_diff\" -> float or None\n", - " \n", - " For chord events, each dict has:\n", - " \"event_type\" -> \"chord\"\n", - " \"reference_index\" -> int (1-based) or None\n", - " \"response_index\" -> int (1-based) or None\n", - " \"operation_type\" -> \"match\", \"replacement\", \"missing\", \"extra\"\n", - " \"chord_name_ref\" -> str e.g. \"C major\", or None\n", - " \"chord_name_res\" -> str e.g. \"C minor\", or None\n", - " \"chord_accuracy\" -> float (0 to 1) or None\n", - " \"correct_pitches\" -> list of pitch class ints or None\n", - " \"missing_pitches\" -> list of pitch class ints or None\n", - " \"extra_pitches\" -> list of pitch class ints or None\n", - " \"timing_correct\" -> bool\n", - " \"timing_abs_diff\" -> float (seconds) or None\n", - " \"timing_relative_diff\" -> float or None\n", - " \"duration_correct\" -> bool\n", - " \"duration_abs_diff\" -> float (seconds) or None\n", - " \"duration_relative_diff\" -> float or None\n", - " \"\"\"\n", - " # Compute IOI for each reference note: ioi[m] = ref_notes[m][\"start\"] - ref_notes[m-1][\"start\"]\n", - " # floor at 0.05s to avoid division by zero issues\n", - " ref_ioi = [None] * len(ref_events)\n", - " for m in range(1, len(ref_events)):\n", - " interval = ref_events[m][\"event_start\"] - ref_events[m - 1][\"event_start\"]\n", - " ref_ioi[m] = max(interval, 0.05)\n", - " \n", - " event_level_results = []\n", - "\n", - " for op in operations:\n", - " res_idx = op[\"response_idx\"]\n", - " ref_idx = op[\"reference_idx\"]\n", - " op_type = op[\"type\"]\n", - "\n", - " # Determine event type from whichever side is available\n", - " if ref_idx is not None:\n", - " event_type = ref_events[ref_idx][\"event_type\"]\n", - " else:\n", - " event_type = response_events[res_idx][\"event_type\"]\n", - "\n", - " # Missing/extra notes: no pitch/timing/duration comparison is possible,\n", - " # so all the numeric fields are set to None.\n", - " if op_type in (\"missing\", \"extra\"):\n", - " if event_type == \"note\":\n", - " event_level_results.append({\n", - " \"event_type\": \"note\",\n", - " \"reference_index\": (ref_idx + 1) if ref_idx is not None else None,\n", - " \"response_index\": (res_idx + 1) if res_idx is not None else None,\n", - " \"operation_type\": op_type,\n", - " \"pitch_correct\": False,\n", - " \"pitch_diff\": None,\n", - " \"timing_correct\": False,\n", - " \"timing_abs_diff\": None,\n", - " \"timing_relative_diff\": None,\n", - " \"duration_correct\": False,\n", - " \"duration_abs_diff\": None,\n", - " \"duration_relative_diff\": None,\n", - " })\n", - " else:\n", - " if ref_idx is not None:\n", - " chord_name = identify_chord_name(ref_events[ref_idx][\"notes\"])\n", - " else:\n", - " chord_name = identify_chord_name(\n", - " response_events[res_idx][\"notes\"]\n", - " )\n", - " event_level_results.append({\n", - " \"event_type\": \"chord\",\n", - " \"reference_index\": (ref_idx + 1) if ref_idx is not None else None,\n", - " \"response_index\": (res_idx + 1) if res_idx is not None else None,\n", - " \"operation_type\": op_type,\n", - " \"chord_name_ref\": chord_name if op_type == \"missing\" else None,\n", - " \"chord_name_res\": chord_name if op_type == \"extra\" else None,\n", - " \"chord_accuracy\": None,\n", - " \"correct_pitches\": None,\n", - " \"missing_pitches\": None,\n", - " \"extra_pitches\": None,\n", - " \"timing_correct\": False,\n", - " \"timing_abs_diff\": None,\n", - " \"timing_relative_diff\": None,\n", - " \"duration_correct\": False,\n", - " \"duration_abs_diff\": None,\n", - " \"duration_relative_diff\": None,\n", - " })\n", - " else:\n", - " # Aligned event pair (match or replacement)\n", - " res_event = response_events[res_idx]\n", - " ref_event = ref_events[ref_idx]\n", - "\n", - " # Timing — residual after removing the global tempo trend\n", - " predicted_start = timing_scale * ref_event[\"event_start\"] + timing_offset\n", - " timing_abs_diff = abs(res_event[\"event_start\"] - predicted_start)\n", - " if ref_idx == 0:\n", - " # First note will start at 0, so no difference.\n", - " timing_relative_diff = None\n", - " timing_correct = True\n", - " else:\n", - " ioi = ref_ioi[ref_idx]\n", - " timing_relative_diff = timing_abs_diff / ioi\n", - " timing_correct = (timing_relative_diff <= timing_relative_threshold)\n", - "\n", - " # Duration — residual after removing the global duration-scale trend\n", - " predicted_duration = duration_scale * ref_event[\"event_duration\"]\n", - " duration_abs_diff = res_event[\"event_duration\"] - predicted_duration\n", - " ref_dur = max(ref_event[\"event_duration\"], 0.05) # floor at 0.05s to avoid division by zero issues\n", - " duration_relative_diff = duration_abs_diff / ref_dur\n", - " duration_correct = (abs(duration_relative_diff) <= duration_relative_threshold)\n", - " if event_type == \"note\":\n", - " # For single-note events, pitch correctness is based on absolute pitch difference.\n", - " pitch1 = res_event[\"notes\"][0][\"pitch\"]\n", - " pitch2 = ref_event[\"notes\"][0][\"pitch\"]\n", - " pitch_diff = int(abs(pitch1 - pitch2))\n", - " event_level_results.append({\n", - " \"event_type\": \"note\",\n", - " \"reference_index\": ref_idx + 1,\n", - " \"response_index\": res_idx + 1,\n", - " \"operation_type\": op_type,\n", - " \"pitch_correct\": (pitch_diff == 0),\n", - " \"pitch_diff\": pitch_diff,\n", - " \"timing_correct\": timing_correct,\n", - " \"timing_abs_diff\": timing_abs_diff,\n", - " \"timing_relative_diff\": timing_relative_diff,\n", - " \"duration_correct\": duration_correct,\n", - " \"duration_abs_diff\": duration_abs_diff,\n", - " \"duration_relative_diff\": duration_relative_diff,\n", - " })\n", - " else:\n", - " # For chord events, pitch correctness is based on the chord accuracy metric.\n", - " accuracy, correct_pcs, missing_pcs, extra_pcs = (\n", - " compute_chord_accuracy(\n", - " ref_event[\"notes\"], res_event[\"notes\"]\n", - " )\n", - " )\n", - " event_level_results.append({\n", - " \"event_type\": \"chord\",\n", - " \"reference_index\": ref_idx + 1,\n", - " \"response_index\": res_idx + 1,\n", - " \"operation_type\": op_type,\n", - " \"chord_name_ref\": identify_chord_name(ref_event[\"notes\"]),\n", - " \"chord_name_res\": identify_chord_name(res_event[\"notes\"]),\n", - " \"chord_accuracy\": accuracy,\n", - " \"correct_pitches\": correct_pcs,\n", - " \"missing_pitches\": missing_pcs,\n", - " \"extra_pitches\": extra_pcs,\n", - " \"timing_correct\": timing_correct,\n", - " \"timing_abs_diff\": timing_abs_diff,\n", - " \"timing_relative_diff\": timing_relative_diff,\n", - " \"duration_correct\": duration_correct,\n", - " \"duration_abs_diff\": duration_abs_diff,\n", - " \"duration_relative_diff\": duration_relative_diff,\n", - " })\n", - "\n", - " return event_level_results\n", - "\n", - "\n", - "\n", - "def compute_stats(event_level_results, ref_events, timing_scale=1.0,\n", - " timing_offset=0.0, duration_scale=1.0):\n", - " \"\"\"\n", - " Compute summary counts and correctness booleans from event-level feedback.\n", - "\n", - " Args:\n", - " event_level_results: list of dicts, output of event_level_feedback()\n", - " ref_events: list of reference event dicts\n", - " timing_scale: float, from estimate_global_timing()\n", - " timing_offset: float, from estimate_global_timing()\n", - " duration_scale: float, from estimate_global_duration_scale()\n", - "\n", - " Returns:\n", - " stats: dict with keys:\n", - " \"pitch_all_correct\" -> bool, True if all note pitches are\n", - " correct AND all chord accuracies are 1.0\n", - " \"timing_all_correct\" -> bool\n", - " \"duration_all_correct\" -> bool\n", - " \"total_notes_in_reference\" -> int\n", - " \"total_notes_missing\" -> int\n", - " \"total_notes_extra\" -> int\n", - " \"total_notes_wrong_pitch\" -> int\n", - " \"total_notes_wrong_timing\" -> int\n", - " \"total_notes_wrong_duration\" -> int\n", - " \"total_notes_correct\" -> int\n", - " \"timing_scale\" -> float\n", - " \"timing_offset\" -> float\n", - " \"duration_scale\" -> float\n", - " \"total_chords_in_reference\" -> int\n", - " \"total_chords_missing\" -> int\n", - " \"total_chords_extra\" -> int\n", - " \"total_chords_correct\" -> int (accuracy == 1.0)\n", - " \"total_chords_imperfect\" -> int (0.0 < accuracy < 1.0)\n", - " \"total_chords_wrong\" -> int (accuracy == 0.0)\n", - " \"\"\"\n", - " note_events = [n for n in event_level_results if n[\"event_type\"] == \"note\"]\n", - " chord_events = [ch for ch in event_level_results if ch[\"event_type\"] == \"chord\"]\n", - " \n", - " ref_note_count = sum(1 for n in ref_events if n[\"event_type\"] == \"note\")\n", - " ref_chord_count = sum(1 for ch in ref_events if ch[\"event_type\"] == \"chord\")\n", - " \n", - " paired_notes = [\n", - " n for n in note_events\n", - " if n[\"operation_type\"] in (\"match\", \"replacement\")\n", - " ]\n", - " paired_chords = [\n", - " ch for ch in chord_events\n", - " if ch[\"operation_type\"] in (\"match\", \"replacement\")\n", - " ]\n", - " all_paired = [\n", - " e for e in event_level_results\n", - " if e[\"operation_type\"] in (\"match\", \"replacement\")\n", - " ]\n", - "\n", - " stats = {\n", - " \"pitch_all_correct\": (\n", - " all(n[\"pitch_correct\"] for n in paired_notes)\n", - " and all(\n", - " ch[\"chord_accuracy\"] == 1.0\n", - " for ch in paired_chords\n", - " if ch[\"chord_accuracy\"] is not None\n", - " )\n", - " ),\n", - " \"timing_all_correct\": all(n[\"timing_correct\"] for n in all_paired),\n", - " \"duration_all_correct\": all(n[\"duration_correct\"] for n in all_paired),\n", - " \"total_notes_in_reference\": ref_note_count,\n", - " \"total_notes_missing\": sum(1 for n in note_events if n[\"operation_type\"] == \"missing\"),\n", - " \"total_notes_extra\": sum(1 for n in note_events if n[\"operation_type\"] == \"extra\"),\n", - " \"total_notes_wrong_pitch\": sum(1 for n in paired_notes if not n[\"pitch_correct\"]),\n", - " \"total_notes_wrong_timing\": sum(1 for n in paired_notes if not n[\"timing_correct\"]),\n", - " \"total_notes_wrong_duration\": sum(1 for n in paired_notes if not n[\"duration_correct\"]),\n", - " \"total_notes_correct\": sum(1 for n in paired_notes\n", - " if n[\"pitch_correct\"] and n[\"timing_correct\"] and n[\"duration_correct\"]\n", - " ),\n", - " \"timing_scale\": timing_scale,\n", - " \"timing_offset\": timing_offset,\n", - " \"duration_scale\": duration_scale,\n", - " \"total_chords_in_reference\": ref_chord_count,\n", - " \"total_chords_missing\": sum(1 for ch in chord_events if ch[\"operation_type\"] == \"missing\"),\n", - " \"total_chords_extra\": sum(1 for ch in chord_events if ch[\"operation_type\"] == \"extra\"),\n", - " \"total_chords_correct\": sum(\n", - " 1 for ch in paired_chords\n", - " if ch[\"chord_accuracy\"] is not None and ch[\"chord_accuracy\"] == 1.0\n", - " ),\n", - " \"total_chords_imperfect\": sum(\n", - " 1 for ch in paired_chords\n", - " if ch[\"chord_accuracy\"] is not None\n", - " and ch[\"chord_accuracy\"] > 0.0\n", - " and ch[\"chord_accuracy\"] < 1.0\n", - " ),\n", - " \"total_chords_wrong\": sum(\n", - " 1 for ch in paired_chords\n", - " if ch[\"chord_accuracy\"] is not None and ch[\"chord_accuracy\"] == 0.0\n", - " ),\n", - " }\n", - " return stats\n", - "\n", - "\n", - "def generate_feedback_message(event_details, response_events, ref_events, stats,\n", - " global_slow_threshold=GLOBAL_SLOW_THRESHOLD,\n", - " global_fast_threshold=GLOBAL_FAST_THRESHOLD):\n", - " \"\"\"\n", - " Generate human-readable feedback messages for the student.\n", - "\n", - " Part 1 - Overview: summary of timing trend, duration trend, and total counts\n", - " of each error type (pitch / missing / extra).\n", - " Part 2 - Note Detail: pitch, timing, duration errors per note\n", - " Part 3 - Chord Detail: errors per chord\n", - "\n", - " Args:\n", - " event_details: list of dicts, output of event_level_feedback()\n", - " response_events: list of event dicts from group_notes_into_events\n", - " ref_events: list of event dicts from group_notes_into_events\n", - " stats: dict, output of compute_stats()\n", - " global_slow_threshold: timing_scale above this triggers \"too slow\" message\n", - " global_fast_threshold: timing_scale below this triggers \"too fast\" message\n", - "\n", - " Returns:\n", - " feedback_message (str)\n", - " \"\"\"\n", - " note_events = [n for n in event_details if n[\"event_type\"] == \"note\"]\n", - " chord_events = [ch for ch in event_details if ch[\"event_type\"] == \"chord\"]\n", - "\n", - " paired_notes = [\n", - " n for n in note_events\n", - " if n[\"operation_type\"] in (\"match\", \"replacement\")\n", - " ]\n", - " paired_chords = [\n", - " ch for ch in chord_events\n", - " if ch[\"operation_type\"] in (\"match\", \"replacement\")\n", - " ]\n", - "\n", - " timing_scale = stats[\"timing_scale\"]\n", - " timing_offset = stats[\"timing_offset\"]\n", - " duration_scale = stats[\"duration_scale\"]\n", - "\n", - " overview_messages = []\n", - " note_detail_messages = []\n", - " chord_detail_messages = []\n", - "\n", - " # ---------- Part 1: Overview ----------\n", - " # Tempo: acceptable / too slow / too fast ---\n", - " timing_pct = abs(timing_scale - 1.0) * 100\n", - " duration_pct = abs(duration_scale - 1.0) * 100\n", - " timing_direction = \"behind\" if timing_scale > 1.0 else \"ahead of\"\n", - " duration_direction = \"longer\" if duration_scale > 1.0 else \"shorter\"\n", - "\n", - " if timing_scale > global_slow_threshold:\n", - " overview_messages.append(\n", - " f\"Overall, your tempo is slower than the reference \"\n", - " f\"(timing is about {timing_pct:.0f}% {timing_direction} the reference in general while \"\n", - " f\"notes are held about {duration_pct:.0f}% {duration_direction} than the reference). \"\n", - " f\"No worries! You will get better when you practice more to get more familiar with it!\"\n", - " )\n", - " elif timing_scale < global_fast_threshold:\n", - " overview_messages.append(\n", - " f\"Overall, your tempo is faster than the reference \"\n", - " f\"(timing is about {timing_pct:.0f}% {timing_direction} the reference in general while \"\n", - " f\"notes are held about {duration_pct:.0f}% {duration_direction} than the reference). \"\n", - " f\"Don't rush even if you are confident in your performance.\" \n", - " f\"Slow down and give each note its full value.\"\n", - " )\n", - " else:\n", - " overview_messages.append(\n", - " f\"Timing: your overall tempo is within an acceptable range. Good job! \"\n", - " f\"The timing is about {timing_pct:.0f}% {timing_direction} the reference in general while \"\n", - " f\"notes are held about {duration_pct:.0f}% {duration_direction} than the reference.\"\n", - " )\n", - "\n", - " # Wrong pitch notes counts\n", - " if stats[\"total_notes_wrong_pitch\"] > 0:\n", - " s = \"is\" if stats[\"total_notes_wrong_pitch\"] == 1 else \"are\"\n", - " note_word = \"note\" if stats[\"total_notes_wrong_pitch\"] == 1 else \"notes\"\n", - " overview_messages.append(\n", - " f\"There {s} {stats['total_notes_wrong_pitch']} {note_word} played with the wrong pitch.\"\n", - " )\n", - " else:\n", - " overview_messages.append(\"There are no pitch errors. Well done!\")\n", - " # Missing notes counts\n", - " if stats[\"total_notes_missing\"] > 0:\n", - " s = \"is\" if stats[\"total_notes_missing\"] == 1 else \"are\"\n", - " note_word = \"note\" if stats[\"total_notes_missing\"] == 1 else \"notes\"\n", - " overview_messages.append(\n", - " f\"There {s} {stats['total_notes_missing']} {note_word} you missed from the reference.\"\n", - " )\n", - " else:\n", - " overview_messages.append(\"There are no missing notes. Great!\")\n", - " # Extra notes counts\n", - " if stats[\"total_notes_extra\"] > 0:\n", - " s = \"is\" if stats[\"total_notes_extra\"] == 1 else \"are\"\n", - " note_word = \"note\" if stats[\"total_notes_extra\"] == 1 else \"notes\"\n", - " overview_messages.append(\n", - " f\"There {s} {stats['total_notes_extra']} extra {note_word} played during practice. \"\n", - " f\"You may need to adjust your fingering or hand position to avoid extra notes.\"\n", - " )\n", - " else:\n", - " overview_messages.append(\"There are no extra notes. Good job!\")\n", - " # Chord errors counts\n", - " if stats[\"total_chords_in_reference\"] > 0:\n", - " total = stats[\"total_chords_in_reference\"]\n", - " correct = stats[\"total_chords_correct\"]\n", - " imperfect = stats[\"total_chords_imperfect\"]\n", - " wrong = stats[\"total_chords_wrong\"]\n", - " overview_messages.append(\n", - " f\"Chords: {correct}/{total} correct, \"\n", - " f\"{imperfect}/{total} imperfect (some notes missing or extra), \"\n", - " f\"{wrong}/{total} completely wrong.\"\n", - " )\n", - " if stats[\"total_chords_missing\"] > 0:\n", - " c_word = \"chord\" if stats[\"total_chords_missing\"] == 1 else \"chords\"\n", - " overview_messages.append(\n", - " f\"{stats['total_chords_missing']} {c_word} missed.\"\n", - " )\n", - " if stats[\"total_chords_extra\"] > 0:\n", - " c_word = \"chord\" if stats[\"total_chords_extra\"] == 1 else \"chords\"\n", - " overview_messages.append(\n", - " f\"{stats['total_chords_extra']} extra {c_word} played.\"\n", - " )\n", - "\n", - " # ---------- Part 2: Note Detail ----------\n", - " # Missing / extra notes\n", - " for n in note_events:\n", - " if n[\"operation_type\"] == \"missing\":\n", - " ref_zero_based = n[\"reference_index\"] - 1\n", - " pitch = ref_events[ref_zero_based][\"notes\"][0][\"pitch\"]\n", - " note_detail_messages.append(\n", - " f\"Note {n['reference_index']} (pitch {pitch}) is missing in your performance.\"\n", - " )\n", - " elif n[\"operation_type\"] == \"extra\":\n", - " res_zero_based = n[\"response_index\"] - 1\n", - " extra = response_events[res_zero_based][\"notes\"][0][\"pitch\"]\n", - " note_detail_messages.append(\n", - " f\"Extra note played: pitch {extra} \"\n", - " f\"at t={response_events[res_zero_based]['event_start']:.2f}s \")\n", - " # Pitch errors\n", - " for n in paired_notes:\n", - " if not n[\"pitch_correct\"]:\n", - " ref_zero_based = n[\"reference_index\"] - 1\n", - " res_zero_based = n[\"response_index\"] - 1\n", - " ref_p = ref_events[ref_zero_based][\"notes\"][0][\"pitch\"]\n", - " res_p = response_events[res_zero_based][\"notes\"][0][\"pitch\"]\n", - " note_detail_messages.append(\n", - " f\"Note {n['reference_index']}: wrong pitch — \"\n", - " f\"expected {ref_p}, played {res_p} \"\n", - " f\"({n['pitch_diff']} semitone(s) off).\"\n", - " )\n", - " # Local timing errors - these are residuals after removing the global timing trend\n", - " for n in paired_notes:\n", - " if not n[\"timing_correct\"]:\n", - " note_detail_messages.append(\n", - " f\"Note {n['reference_index']}: timing is off by \"\n", - " f\"{n['timing_abs_diff']:.2f}s \"\n", - " f\"({n['timing_relative_diff'] * 100:.0f}% of the expected note interval), \"\n", - " f\"after accounting for the overall tempo trend.\"\n", - " )\n", - " # Local duration errors — these are residuals after removing the global duration trend\n", - " for n in paired_notes:\n", - " if not n[\"duration_correct\"]:\n", - " direction = \"longer\" if n[\"duration_abs_diff\"] > 0 else \"shorter\"\n", - " duration_pct_err = abs(n[\"duration_relative_diff\"]) * 100\n", - " note_detail_messages.append(\n", - " f\"Note {n['reference_index']}: duration is \"\n", - " f\"{abs(n['duration_abs_diff']):.2f}s {direction} than the reference \"\n", - " f\"({duration_pct_err:.0f}% off) after accounting for the overall duration trend.\"\n", - " )\n", - "\n", - " # ---------- Part 3: Chord Detail ----------\n", - " # Missing / extra chords\n", - " for ch in chord_events:\n", - " if ch[\"operation_type\"] == \"missing\":\n", - " chord_detail_messages.append(\n", - " f\"Chord {ch['reference_index']} ({ch['chord_name_ref']}) \"\n", - " f\"is missing in your performance.\"\n", - " )\n", - " elif ch[\"operation_type\"] == \"extra\":\n", - " chord_detail_messages.append(\n", - " f\"Extra chord played: {ch['chord_name_res']} \"\n", - " f\"at event position {ch['response_index']}.\"\n", - " )\n", - " # Chord accuracy errors\n", - " for ch in paired_chords:\n", - " if ch[\"chord_accuracy\"] is not None and ch[\"chord_accuracy\"] < 1.0:\n", - " accuracy_pct = round(ch[\"chord_accuracy\"] * 100)\n", - " message = (\n", - " f\"Chord {ch['reference_index']} \"\n", - " f\"(expected {ch['chord_name_ref']}, you played {ch['chord_name_res']}): \"\n", - " f\"{accuracy_pct}% accurate. \"\n", - " )\n", - " if ch[\"missing_pitches\"]:\n", - " missing_names = [PITCH_CLASS_NAMES[pc] for pc in ch[\"missing_pitches\"]]\n", - " message = message + \"Missing note(s): \" + \", \".join(missing_names) + \". \"\n", - " if ch[\"extra_pitches\"]:\n", - " extra_names = [PITCH_CLASS_NAMES[pc] for pc in ch[\"extra_pitches\"]]\n", - " message = message + \"Extra note(s) played: \" + \", \".join(extra_names) + \".\"\n", - " chord_detail_messages.append(message)\n", - " # Local timing errors for chords\n", - " for ch in paired_chords:\n", - " if not ch[\"timing_correct\"] and ch[\"timing_relative_diff\"] is not None:\n", - " chord_detail_messages.append(\n", - " f\"Chord {ch['reference_index']}: timing is off by \"\n", - " f\"{ch['timing_abs_diff']:.2f}s \"\n", - " f\"({ch['timing_relative_diff'] * 100:.0f}% of the expected interval).\"\n", - " )\n", - "\n", - " all_messages = [\"Overview: \"] + overview_messages\n", - "\n", - " if note_detail_messages:\n", - " all_messages = all_messages + [\"\", \"Note Detail:\"] + note_detail_messages\n", - " else:\n", - " all_messages = all_messages + [\"\", \"All notes(melody part) played correctly!\"]\n", - " \n", - " if stats[\"total_chords_in_reference\"] > 0:\n", - " if chord_detail_messages:\n", - " all_messages = (\n", - " all_messages + [\"\", \"Chord Detail:\"] + chord_detail_messages\n", - " )\n", - " else:\n", - " all_messages = all_messages + [\"\", \"Great performance! No further issues on chords found.\"]\n", - " \n", - " return \"\\n\".join(all_messages)\n", - "\n", - "\n", - "# FeedbackResult class\n", - "# ------------------------------------------------------------------------------\n", - "class FeedbackResult:\n", - " \"\"\"\n", - " Container for all outputs of compare_performance_ED().\n", - " Using a class (instead of returning a tuple) makes unit tests much clearer:\n", - " result = compare_performance_ED(response, reference)\n", - " assert result.is_correct == False\n", - " assert result.stats[\"total_notes_missing\"] == 1\n", - " assert \"missing\" in result.feedback_message\n", - "\n", - " Attributes\n", - " ----------\n", - " is_correct : bool\n", - " True only if all notes and chords are perfectly correct.\n", - " stats : dict\n", - " Aggregate counts. See compute_stats() for all keys.\n", - " event_details : list of dicts\n", - " Per-event analysis, one dict per alignment operation.\n", - " Each dict has an \"event_type\" key (\"note\" or \"chord\"),\n", - " plus type-specific fields. See event_level_feedback() for details.\n", - " feedback_message : str\n", - " Human-readable feedback string, ready to display to the student.\n", - " operations : list of dicts\n", - " Raw alignment operations from event_alignment_ED().\n", - " Kept here so visualisation helpers (plot_cost_matrix etc.) can use them.\n", - " D : numpy.ndarray\n", - " Accumulated cost matrix from the alignment step.\n", - " \"\"\"\n", - "\n", - " def __init__(self, is_correct, stats, event_details,\n", - " feedback_message, operations, D):\n", - " self.is_correct = is_correct\n", - " self.stats = stats\n", - " self.event_details = event_details\n", - " self.feedback_message = feedback_message\n", - " self.operations = operations\n", - " self.D = D\n", - " \n", - " def __repr__(self):\n", - " return (\n", - " \"FeedbackResult(is_correct=\" + str(self.is_correct) + \", \"\n", - " \"stats=\" + str(self.stats) + \")\"\n", - " )\n", - "\n", - "\n", - "def compare_performance_ED(responseMIDI, refMIDI,\n", - " gap_penalty=DEFAULT_GAP_PENALTY,\n", - " timing_relative_threshold=TIMING_RELATIVE_THRESHOLD,\n", - " duration_relative_threshold=DURATION_RELATIVE_THRESHOLD,\n", - " global_slow_threshold=GLOBAL_SLOW_THRESHOLD,\n", - " global_fast_threshold=GLOBAL_FAST_THRESHOLD,\n", - " chord_onset_window=DEFAULT_CHORD_ONSET_WINDOW):\n", - " \"\"\"\n", - " Full pipeline: normalisation -> grouping -> alignment -> global trends\n", - " -> event-level evaluation -> summary statistics -> feedback.\n", - " \n", - " Args:\n", - " responseMIDI: student MIDI dict with key \"notes\"\n", - " refMIDI: reference MIDI dict with key \"notes\"\n", - " gap_penalty: cost of an unaligned event\n", - " timing_relative_threshold: see event_level_feedback()\n", - " duration_relative_threshold: see event_level_feedback()\n", - " global_slow_threshold: see generate_feedback_message()\n", - " global_fast_threshold: see generate_feedback_message()\n", - " chord_onset_window: float (seconds), notes within this window are\n", - " grouped into a chord. Default 0.050 (50ms). Teacher-configurable.\n", - " \n", - " Returns:\n", - " FeedbackResult object containing all analysis results\n", - " \"\"\"\n", - " # Step 0: Normalise start times\n", - " response_notes = normalize_start_times(responseMIDI[\"notes\"])\n", - " ref_notes = normalize_start_times(refMIDI[\"notes\"])\n", - " # Group notes into events (single notes or chords)\n", - " response_events = group_notes_into_events(response_notes, chord_onset_window)\n", - " ref_events = group_notes_into_events(ref_notes, chord_onset_window)\n", - "\n", - " # Step 1: Align events using edit distance\n", - " operations, D = event_alignment_ED(response_events, ref_events, gap_penalty)\n", - "\n", - " # Step 2: Estimate the overall tempo trend\n", - " timing_scale, timing_offset = estimate_global_timing(\n", - " operations, response_events, ref_events\n", - " )\n", - " duration_scale = estimate_global_duration_scale(\n", - " operations, response_events, ref_events\n", - " )\n", - "\n", - " # Step 3: Event-level evaluation\n", - " event_details = event_level_feedback(\n", - " operations, response_events, ref_events,\n", - " timing_scale=timing_scale,\n", - " timing_offset=timing_offset,\n", - " duration_scale=duration_scale,\n", - " timing_relative_threshold=timing_relative_threshold,\n", - " duration_relative_threshold=duration_relative_threshold,\n", - " )\n", - "\n", - " # Step 4: Compute summary statistics\n", - " stats = compute_stats(\n", - " event_details, ref_events,\n", - " timing_scale=timing_scale,\n", - " timing_offset=timing_offset,\n", - " duration_scale=duration_scale,\n", - " )\n", - "\n", - " # Step 5: Generate human-readable feedback\n", - " feedback_message = generate_feedback_message(\n", - " event_details, response_events, ref_events, stats,\n", - " global_slow_threshold=global_slow_threshold,\n", - " global_fast_threshold=global_fast_threshold,\n", - " )\n", - "\n", - " # Step 6: Overall pass/fail judgement\n", - " is_correct = (\n", - " stats[\"total_notes_missing\"] == 0\n", - " and stats[\"total_notes_extra\"] == 0\n", - " and stats[\"total_chords_missing\"] == 0\n", - " and stats[\"total_chords_extra\"] == 0\n", - " and stats[\"pitch_all_correct\"]\n", - " and stats[\"timing_all_correct\"]\n", - " and stats[\"duration_all_correct\"]\n", - " )\n", + "chord accuracy metric: https://arxiv.org/pdf/2201.05244\n", "\n", - " return FeedbackResult(\n", - " is_correct=is_correct,\n", - " stats=stats,\n", - " event_details=event_details,\n", - " feedback_message=feedback_message,\n", - " operations=operations,\n", - " D=D,\n", - " )" + "consider chord as unordered set of pitch classes\n", + "https://www.researchgate.net/profile/Pierre-Hanna/publication/265984596_A_Survey_Of_Chord_Distances_With_Comparison_For_Chord_Analysis/links/555c4c7808aec5ac2232b158/A-Survey-Of-Chord-Distances-With-Comparison-For-Chord-Analysis.pdf" ] }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 37, "id": "5317bdff", "metadata": {}, "outputs": [ @@ -1267,7 +98,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 38, "id": "7b90ed72", "metadata": {}, "outputs": [ @@ -1303,7 +134,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 39, "id": "81edb080", "metadata": {}, "outputs": [ @@ -1314,7 +145,7 @@ "Purpose: Two chords: first is perfect C major, second is imperfect A minor (missing E). total_chords_correct = 1, total_chords_imperfect = 1.\n", "\n", "is_correct: False\n", - "stats: {'pitch_all_correct': False, 'timing_all_correct': True, 'duration_all_correct': True, 'total_notes_in_reference': 0, 'total_notes_missing': 0, 'total_notes_extra': 0, 'total_notes_wrong_pitch': 0, 'total_notes_wrong_timing': 0, 'total_notes_wrong_duration': 0, 'total_notes_correct': 0, 'timing_scale': 1.0, 'timing_offset': 0.0, 'duration_scale': 1.0, 'total_chords_in_reference': 2, 'total_chords_missing': 0, 'total_chords_extra': 0, 'total_chords_correct': 1, 'total_chords_imperfect': 1, 'total_chords_wrong': 0}\n", + "stats: {'pitch_all_aligned_correct': False, 'timing_all_correct': True, 'duration_all_correct': True, 'total_notes_in_reference': 0, 'total_notes_missing': 0, 'total_notes_extra': 0, 'total_notes_wrong_pitch': 0, 'total_notes_wrong_timing': 0, 'total_notes_wrong_duration': 0, 'total_notes_correct': 0, 'timing_scale': 1.0, 'timing_offset': 0.0, 'duration_scale': 1.0, 'total_chords_in_reference': 2, 'total_chords_missing': 0, 'total_chords_extra': 0, 'total_chords_correct': 1, 'total_chords_imperfect': 1, 'total_chords_wrong': 0}\n", "\n", "Overview: \n", "Timing: your overall tempo is within an acceptable range. Good job! The timing is about 0% ahead of the reference in general while notes are held about 0% shorter than the reference.\n", @@ -1348,7 +179,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 41, "id": "83a5b636", "metadata": {}, "outputs": [ @@ -1377,7 +208,7 @@ " print(\n", " f\"[{case['case_id']:40s}] \"\n", " f\"is_correct={result.is_correct!s:5} \"\n", - " f\"pitch_all={s['pitch_all_correct']!s:5} \"\n", + " f\"pitch_all={s['pitch_all_aligned_correct']!s:5} \"\n", " f\"timing_all={s['timing_all_correct']!s:5} \"\n", " f\"duration_all={s['duration_all_correct']!s:5} \"\n", " f\"notes_ref={s['total_notes_in_reference']} \"\n", From 77431a0c0cd195d7853c6fe9ad3616bbbff54244 Mon Sep 17 00:00:00 2001 From: ada-3e212e610b Date: Sat, 27 Jun 2026 21:06:03 +0100 Subject: [PATCH 6/8] add tests for chord comparison --- evaluation_function/evaluation_test.py | 424 +++++++++++++++++++++---- 1 file changed, 356 insertions(+), 68 deletions(-) diff --git a/evaluation_function/evaluation_test.py b/evaluation_function/evaluation_test.py index 73f5eaa..7094391 100755 --- a/evaluation_function/evaluation_test.py +++ b/evaluation_function/evaluation_test.py @@ -8,34 +8,45 @@ Sections -------- -1. Helper: make_midi -2. Tests for normalize_start_times -3. Tests for note_alignment_ED -4. Tests for estimate_global_timing and estimate_global_duration_scale -5. Tests for note_level_feedback and compute_stats -6. Tests for evaluation_function (Lambda Feedback integration) -7. Tests for parameter overrides +0. Helper: make_midi +1. Tests for helper functions get_pitch_class_set, identify_chord_name, compute_chord_accuracy +2. Tests for normalize_start_times +3. Tests for group_notes_into_events +4. Tests for compute_note_cost and compute_event_cost +5. Tests for event_alignment_ED (covers note-only, chord-only cases and mixed cases) +6. Tests for estimate_global_timing and estimate_global_duration_scale +7. Tests for event_level_feedback and compute_stats (covers note-only and chord-only cases) +8. Tests for evaluation_function (Lambda Feedback integration) +9. Tests for parameter overrides """ import unittest -import json from .compare_MIDI import ( normalize_start_times, - compute_cost, - note_alignment_ED, + group_notes_into_events, + make_event, + compute_note_cost, + compute_event_cost, + get_pitch_class_set, + identify_chord_name, + compute_chord_accuracy, + event_alignment_ED, estimate_global_timing, estimate_global_duration_scale, + event_level_feedback, + compute_stats, compare_performance_ED, DEFAULT_GAP_PENALTY, TIMING_RELATIVE_THRESHOLD, DURATION_RELATIVE_THRESHOLD, GLOBAL_SLOW_THRESHOLD, GLOBAL_FAST_THRESHOLD, + DEFAULT_CHORD_ONSET_WINDOW, ) from .evaluation import evaluation_function -# 1. Helper: make MIDI notes for testing +# 0. Helper: create a minimal MIDI dictionary for testing # ------------------------------------------------------------------------------ def make_midi(pitches, starts, durations): notes = [] @@ -48,9 +59,64 @@ def make_midi(pitches, starts, durations): return {"notes": notes} +# 1. Tests for helper functions get_pitch_class_set, identify_chord_name, compute_chord_accuracy +# ------------------------------------------------------------------------------ +class TestChordHelpers(unittest.TestCase): + + def test_get_pitch_class_set_single_octave(self): + notes = [{"pitch": 60}, {"pitch": 64}, {"pitch": 67}] + result = get_pitch_class_set(notes) + assert result == {0, 4, 7} + + def test_get_pitch_class_set_cross_octave(self): + # C4=60 and C5=72 both map to pitch class 0 + notes = [{"pitch": 60}, {"pitch": 72}] + result = get_pitch_class_set(notes) + assert result == {0} + + def test_identify_c_major(self): + notes = [ + {"pitch": 60}, # C + {"pitch": 64}, # E + {"pitch": 67}, # G + ] + assert identify_chord_name(notes) == "C major" + + def test_identify_unknown_chord(self): + # A random cluster that does not match any template + notes = [{"pitch": 60}, {"pitch": 61}, {"pitch": 62}] + assert identify_chord_name(notes) == "unknown chord" + + def test_perfect_accuracy_is_one(self): + ref = [{"pitch": 60}, {"pitch": 64}, {"pitch": 67}] + res = [{"pitch": 60}, {"pitch": 64}, {"pitch": 67}] + accuracy, correct, missing, extra = compute_chord_accuracy(ref, res) + assert accuracy == 1.0 + assert missing == [] + assert extra == [] + + def test_completely_wrong_notes(self): + # Response has no overlap with reference + ref = [{"pitch": 60}, {"pitch": 64}, {"pitch": 67}] + res = [{"pitch": 61}, {"pitch": 65}, {"pitch": 69}] + accuracy, correct, missing, extra = compute_chord_accuracy(ref, res) + assert accuracy == 0.0 + assert len(correct) == 0 + + def test_partial_match_returns_score_between_0_and_1(self): + # C major ref (0,4,7) vs C-minor response (0,3,7) -- one note off + ref = [{"pitch": 60}, {"pitch": 64}, {"pitch": 67}] + res = [{"pitch": 60}, {"pitch": 63}, {"pitch": 67}] + accuracy, correct, missing, extra = compute_chord_accuracy(ref, res) + assert 0.0 < accuracy < 1.0 + assert len(correct) == 2 + assert 4 in missing # pitch class 4 (E) is missing + assert 3 in extra # pitch class 3 (D#) is extra + + # 2. Tests for normalize_start_times # ------------------------------------------------------------------------------ -class TestNormalizeStartTimes: +class TestNormalizeStartTimes(unittest.TestCase): def test_first_note_starts_at_zero(self): notes = make_midi([60, 62], [1.0, 1.5], [0.5, 0.5])["notes"] @@ -68,124 +134,284 @@ def test_pitch_and_duration_unchanged(self): assert result[0]["pitch"] == 64 assert result[0]["duration"] == 0.8 + def test_already_starts_at_zero_unchanged(self): + notes = make_midi([60, 62], [0.0, 0.5], [0.4, 0.4])["notes"] + result = normalize_start_times(notes) + assert result[0]["start"] == 0.0 + assert abs(result[1]["start"] - 0.5) < 0.0001 + -# 3. Tests for note_alignment_ED +# 3. Tests for group_notes_into_events # ------------------------------------------------------------------------------ -class TestNoteAlignmentED: +class TestGroupNotesIntoEvents(unittest.TestCase): + + def test_single_notes_not_grouped(self): + # Notes 0.5s apart -- should form separate events + notes = make_midi([60, 62, 64], [0.0, 0.5, 1.0], [0.4, 0.4, 0.4])["notes"] + events = group_notes_into_events(notes) + assert len(events) == 3 + for event in events: + assert len(event["notes"]) == 1 + + def test_three_note_chord_grouped_correctly(self): + # C major: all three notes at the same start time + notes = make_midi([60, 64, 67], [0.0, 0.0, 0.0], [0.5, 0.5, 0.5])["notes"] + events = group_notes_into_events(notes) + assert len(events) == 1 + assert events[0]["event_type"] == "chord" + assert len(events[0]["notes"]) == 3 + + def test_notes_within_window_grouped(self): + # Notes 30ms apart -- within the default 50ms window -> one chord event + notes = make_midi([60, 64], [0.00, 0.03], [0.5, 0.5])["notes"] + events = group_notes_into_events(notes, chord_onset_window=0.05) + assert len(events) == 1 + assert events[0]["event_type"] == "chord" + assert len(events[0]["notes"]) == 2 + + def test_notes_outside_window_not_grouped(self): + # Notes 100ms apart -- outside the default 50ms window + notes = make_midi([60, 64], [0.00, 0.10], [0.5, 0.5])["notes"] + events = group_notes_into_events(notes, chord_onset_window=0.05) + assert len(events) == 2 + for event in events: + assert event["event_type"] == "note" + assert len(event["notes"]) == 1 + + def test_event_start_equals_first_note_start(self): + notes = make_midi([60, 64], [0.01, 0.02], [0.5, 0.5])["notes"] + events = group_notes_into_events(notes, chord_onset_window=0.05) + assert abs(events[0]["event_start"] - 0.010) < 0.0001 + + def test_event_duration_equals_max_note_duration(self): + notes = make_midi([60, 64], [0.0, 0.0], [0.4, 0.6])["notes"] + events = group_notes_into_events(notes) + assert abs(events[0]["event_duration"] - 0.6) < 0.0001 + + def test_custom_window_zero_means_no_grouping(self): + notes = make_midi([60, 64], [0.0, 0.001], [0.5, 0.5])["notes"] + events = group_notes_into_events(notes, chord_onset_window=0.0) + assert len(events) == 2 + + +# 4. Tests for compute_note_cost and compute_event_cost +# ------------------------------------------------------------------------------ +class TestCostComputations(unittest.TestCase): + + def test_note_vs_note_different_pitch(self): + e1 = make_event(make_midi([60], [0.0], [0.5])["notes"]) + e2 = make_event(make_midi([65], [0.0], [0.5])["notes"]) + assert compute_event_cost(e1, e2) == 5 + + def test_chord_vs_chord_one_note_different(self): + # C major (0,4,7) vs C-minor (0,3,7): one pitch differs -> Hamming = 2 + e1 = make_event(make_midi([60, 64, 67], [0.0, 0.0, 0.0], [0.5, 0.5, 0.5])["notes"]) + e2 = make_event(make_midi([60, 63, 67], [0.0, 0.0, 0.0], [0.5, 0.5, 0.5])["notes"]) + cost = compute_event_cost(e1, e2) + assert cost == 2 + + def test_type_mismatch_returns_gap_penalty(self): + # note vs chord -- should return the gap_penalty regardless of pitches + note_event = make_event(make_midi([60], [0.0], [0.5])["notes"]) + chord_event = make_event(make_midi([60, 64], [0.0, 0.0], [0.5, 0.5])["notes"]) + cost = compute_event_cost(note_event, chord_event, gap_penalty=10) + assert cost == 10 + + +# 5. Tests for event_alignment_ED +# ------------------------------------------------------------------------------ +class TestEventAlignmentED(unittest.TestCase): def test_perfect_match_all_match_ops(self): ref = make_midi([60, 62, 64], [0, 0.5, 1.0], [0.4, 0.4, 0.4]) res = make_midi([60, 62, 64], [0, 0.5, 1.0], [0.4, 0.4, 0.4]) - operations, D = note_alignment_ED(res["notes"], ref["notes"]) + ref_events = group_notes_into_events(ref["notes"]) + res_events = group_notes_into_events(res["notes"]) + operations, D = event_alignment_ED(res_events, ref_events) types = [op["type"] for op in operations] assert all(t == "match" for t in types) def test_missing_note_detected(self): # pitch 62 missing in response ref = make_midi([60, 62, 64], [0, 0.5, 1.0], [0.4, 0.4, 0.4]) - res = make_midi([60, 64], [0, 1.0], [0.4, 0.4]) - operations, D = note_alignment_ED(res["notes"], ref["notes"]) + res = make_midi([60, 64], [0, 1.0], [0.4, 0.4]) + ref_events = group_notes_into_events(ref["notes"]) + res_events = group_notes_into_events(res["notes"]) + operations, D = event_alignment_ED(res_events, ref_events) types = [op["type"] for op in operations] assert "missing" in types def test_extra_note_detected(self): # extra pitch 62 in response - ref = make_midi([60, 64], [0, 1.0], [0.4, 0.4]) + ref = make_midi([60, 64], [0, 1.0], [0.4, 0.4]) res = make_midi([60, 62, 64], [0, 0.5, 1.0], [0.4, 0.4, 0.4]) - operations, D = note_alignment_ED(res["notes"], ref["notes"]) + ref_events = group_notes_into_events(ref["notes"]) + res_events = group_notes_into_events(res["notes"]) + operations, D = event_alignment_ED(res_events, ref_events) types = [op["type"] for op in operations] assert "extra" in types def test_wrong_pitch_is_replacement(self): ref = make_midi([60, 62], [0, 0.5], [0.4, 0.4]) res = make_midi([60, 65], [0, 0.5], [0.4, 0.4]) - operations, D = note_alignment_ED(res["notes"], ref["notes"]) + ref_events = group_notes_into_events(ref["notes"]) + res_events = group_notes_into_events(res["notes"]) + operations, D = event_alignment_ED(res_events, ref_events) replacements = [op for op in operations if op["type"] == "replacement"] assert len(replacements) == 1 + + def test_identical_chords_are_matched(self): + # Two identical C major chords + ref = make_midi([60, 64, 67, 60, 64, 67], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0], [0.5] * 6) + res = make_midi([60, 64, 67, 60, 64, 67], [0.0, 0.0, 0.0, 1.0, 1.0, 1.0], [0.5] * 6) + ref_events = group_notes_into_events(ref["notes"]) + res_events = group_notes_into_events(res["notes"]) + operations, D = event_alignment_ED(res_events, ref_events) + types = [op["type"] for op in operations] + assert all(t == "match" for t in types) + + def test_different_chords_are_replacements(self): + # C major ref, C-minor response (E replaced by Eb) + ref = make_midi([60, 64, 67], [0.0, 0.0, 0.0], [0.5, 0.5, 0.5]) + res = make_midi([60, 63, 67], [0.0, 0.0, 0.0], [0.5, 0.5, 0.5]) + ref_events = group_notes_into_events(ref["notes"]) + res_events = group_notes_into_events(res["notes"]) + operations, D = event_alignment_ED(res_events, ref_events) + assert operations[0]["type"] == "replacement" + + def test_mix_note_and_chord_aligned_correctly(self): + # Reference: one single note (C4) then one C major chord. + # Response: matches exactly. + # Checks that note and chord events are both aligned as "match" in the same sequence. + ref = make_midi([60, 60, 64, 67], [0.0, 1.0, 1.0, 1.0], [0.4, 0.5, 0.5, 0.5]) + res = make_midi([60, 60, 64, 67], [0.0, 1.0, 1.0, 1.0], [0.4, 0.5, 0.5, 0.5]) + ref_events = group_notes_into_events(ref["notes"]) + res_events = group_notes_into_events(res["notes"]) + operations, D = event_alignment_ED(res_events, ref_events) + types = [op["type"] for op in operations] + assert all(t == "match" for t in types) + # First event is a note, second is a chord + assert ref_events[0]["event_type"] == "note" + assert ref_events[1]["event_type"] == "chord" def test_ops_are_in_forward_order(self): ref = make_midi([60, 62, 64], [0, 0.5, 1.0], [0.4, 0.4, 0.4]) res = make_midi([60, 62, 64], [0, 0.5, 1.0], [0.4, 0.4, 0.4]) - operations, D = note_alignment_ED(res["notes"], ref["notes"]) + ref_events = group_notes_into_events(ref["notes"]) + res_events = group_notes_into_events(res["notes"]) + operations, D = event_alignment_ED(res_events, ref_events) ref_indices = [op["reference_idx"] for op in operations if op["reference_idx"] is not None] assert ref_indices == sorted(ref_indices) def test_cost_matrix_shape(self): ref = make_midi([60, 62, 64], [0, 0.5, 1.0], [0.4, 0.4, 0.4]) - res = make_midi([60, 62], [0, 0.5], [0.4, 0.4]) - operations, D = note_alignment_ED(res["notes"], ref["notes"]) - assert D.shape == (len(res["notes"]) + 1, len(ref["notes"]) + 1) + res = make_midi([60, 62], [0, 0.5], [0.4, 0.4]) + ref_events = group_notes_into_events(ref["notes"]) + res_events = group_notes_into_events(res["notes"]) + operations, D = event_alignment_ED(res_events, ref_events) + assert D.shape == (len(res_events) + 1, len(ref_events) + 1) -# 4. Tests for estimate_global_timing and estimate_global_duration_scale +# 6. Tests for estimate_global_timing and estimate_global_duration_scale # ------------------------------------------------------------------------------ -class TestGlobalEstimations: +class TestTrendEstimations(unittest.TestCase): def test_perfect_timing_scale_is_one(self): ref = make_midi([60, 62, 64, 65], [0, 0.5, 1.0, 1.5], [0.4] * 4) res = make_midi([60, 62, 64, 65], [0, 0.5, 1.0, 1.5], [0.4] * 4) - operations, D = note_alignment_ED(res["notes"], ref["notes"]) - scale, offset = estimate_global_timing(operations, res["notes"], ref["notes"]) + ref_ev = group_notes_into_events(ref["notes"]) + res_ev = group_notes_into_events(res["notes"]) + operations, D = event_alignment_ED(res_ev, ref_ev) + scale, offset = estimate_global_timing(operations, res_ev, ref_ev) assert abs(scale - 1.0) < 0.01 - - def test_slower_playing_scale_greater_than_one(self): + + def test_faster_playing_scale_less_than_one(self): ref = make_midi([60, 62, 64, 65], [0, 0.5, 1.0, 1.5], [0.4] * 4) - res = make_midi([60, 62, 64, 65], [0, 0.6, 1.2, 1.8], [0.4] * 4) # 20% slower - operations, D = note_alignment_ED(res["notes"], ref["notes"]) - scale, offset = estimate_global_timing(operations, res["notes"], ref["notes"]) + res = make_midi([60, 62, 64, 65], [0, 0.4, 0.8, 1.2], [0.4] * 4) + ref_ev = group_notes_into_events(ref["notes"]) + res_ev = group_notes_into_events(res["notes"]) + operations, D = event_alignment_ED(res_ev, ref_ev) + scale, offset = estimate_global_timing(operations, res_ev, ref_ev) + assert scale < 1.0 + + def test_slower_playing_with_mixed_note_and_chord(self): + # One note, one C major chord, then one note. + # The response is played consistently 20% slower. + # Global timing scale should therefore be greater than 1.0. + ref = make_midi( + [60, 60, 64, 67, 72], + [0.0, 1.0, 1.0, 1.0, 2.0], + [0.4, 0.5, 0.5, 0.5, 0.4] + ) + res = make_midi( + [60, 60, 64, 67, 72], + [0.0, 1.2, 1.2, 1.2, 2.4], + [0.4, 0.5, 0.5, 0.5, 0.4] + ) + ref_ev = group_notes_into_events(ref["notes"]) + res_ev = group_notes_into_events(res["notes"]) + operations, D = event_alignment_ED(res_ev, ref_ev) + scale, offset = estimate_global_timing(operations, res_ev, ref_ev) assert scale > 1.0 def test_perfect_duration_scale_is_one(self): ref = make_midi([60, 62, 64, 65], [0, 0.5, 1.0, 1.5], [0.4] * 4) res = make_midi([60, 62, 64, 65], [0, 0.5, 1.0, 1.5], [0.4] * 4) - operations, D = note_alignment_ED(res["notes"], ref["notes"]) - dur_scale = estimate_global_duration_scale(operations, res["notes"], ref["notes"]) + ref_ev = group_notes_into_events(ref["notes"]) + res_ev = group_notes_into_events(res["notes"]) + operations, D = event_alignment_ED(res_ev, ref_ev) + dur_scale = estimate_global_duration_scale(operations, res_ev, ref_ev) assert abs(dur_scale - 1.0) < 0.01 - + + def test_longer_durations_scale_greater_than_one(self): + ref = make_midi([60, 62, 64, 65], [0, 0.5, 1.0, 1.5], [0.4] * 4) + res = make_midi([60, 62, 64, 65], [0, 0.5, 1.0, 1.5], [0.6] * 4) + ref_ev = group_notes_into_events(ref["notes"]) + res_ev = group_notes_into_events(res["notes"]) + operations, D = event_alignment_ED(res_ev, ref_ev) + dur_scale = estimate_global_duration_scale(operations, res_ev, ref_ev) + assert dur_scale > 1.0 + def test_fewer_than_3_matched_returns_defaults(self): # Only 2 notes -- should return (1.0, 0.0) ref = make_midi([60, 62], [0, 0.5], [0.4, 0.4]) res = make_midi([60, 62], [0, 0.5], [0.4, 0.4]) - operations, D = note_alignment_ED(res["notes"], ref["notes"]) - scale, offset = estimate_global_timing(operations, res["notes"], ref["notes"]) - dur_scale = estimate_global_duration_scale(operations, res["notes"], ref["notes"]) + ref_events = group_notes_into_events(ref["notes"]) + res_events = group_notes_into_events(res["notes"]) + operations, D = event_alignment_ED(res_events, ref_events) + scale, offset = estimate_global_timing(operations, res_events, ref_events) + dur_scale = estimate_global_duration_scale(operations, res_events, ref_events) assert scale == 1.0 assert offset == 0.0 assert dur_scale == 1.0 -# 5. Tests for note_level_feedback and compute_stats +# 7. Tests for event_level_feedback and compute_stats # ------------------------------------------------------------------------------ -class TestComparePerformanceED: +class TestComparePerformanceED(unittest.TestCase): - def test_consistent_tempo_not_flagged_per_note(self): - """ - A student playing consistently 20% slower should NOT get - note-level timing warnings -- only a global tempo comment. - """ + def test_consistent_tempo_not_flagged_per_event(self): + # A student playing consistently 20% slower should NOT get + # event-level timing warnings -- only a global tempo comment. ref = make_midi([60, 62, 64, 65], [0, 0.5, 1.0, 1.5], [0.4] * 4) res = make_midi([60, 62, 64, 65], [0, 0.6, 1.2, 1.8], [0.4] * 4) result = compare_performance_ED(res, ref) - for n in result.note_details: + for n in result.event_details: if n["operation_type"] in ("match", "replacement"): assert n["timing_correct"] == True def test_single_late_note_flagged(self): - """ - One note that is very late compared to the rest should be flagged - after the global trend is removed. - """ ref = make_midi([60, 62, 64, 65], [0, 0.5, 1.0, 1.5], [0.4] * 4) - res = make_midi([60, 62, 64, 65], [0, 0.5, 1.8, 1.5], [0.4] * 4) # note 3 very late + res = make_midi([60, 62, 64, 65], [0, 0.5, 1.8, 1.5], [0.4] * 4) result = compare_performance_ED(res, ref) - flagged = [n for n in result.note_details if not n["timing_correct"]] + flagged = [n for n in result.event_details if not n["timing_correct"]] assert len(flagged) > 0 def test_pitch_error_recorded_correctly(self): ref = make_midi([60, 62], [0, 0.5], [0.4, 0.4]) - res = make_midi([60, 65], [0, 0.5], [0.4, 0.4]) # 3 semitones off + res = make_midi([60, 65], [0, 0.5], [0.4, 0.4]) result = compare_performance_ED(res, ref) - replacements = [n for n in result.note_details if n["operation_type"] == "replacement"] + replacements = [n for n in result.event_details if n["operation_type"] == "replacement"] assert len(replacements) == 1 assert replacements[0]["pitch_diff"] == 3 assert replacements[0]["pitch_correct"] == False @@ -196,7 +422,7 @@ def test_all_correct_stats(self): assert result.stats["total_notes_missing"] == 0 assert result.stats["total_notes_extra"] == 0 assert result.stats["total_notes_wrong_pitch"] == 0 - assert result.stats["pitch_all_correct"] == True + assert result.stats["pitch_all_aligned_correct"] == True def test_missing_note_counted(self): ref = make_midi([60, 62, 64], [0, 0.5, 1.0], [0.4, 0.4, 0.4]) @@ -216,35 +442,86 @@ def test_total_notes_in_reference(self): result = compare_performance_ED(res, ref) assert result.stats["total_notes_in_reference"] == 3 + def test_identical_chords_is_correct(self): + # Two chords: C major at t=0, F-major at t=1 + midi = make_midi([60, 64, 67, 65, 69, 72], + [0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [0.5] * 6) + result = compare_performance_ED(midi, midi) + assert result.is_correct == True + assert result.stats["total_chords_in_reference"] == 2 + assert result.stats["total_chords_correct"] == 2 + + def test_missing_chord_counted(self): + # Reference has C major and F-major; response only has C major + ref = make_midi([60, 64, 67, 65, 69, 72], + [0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [0.5] * 6) + res = make_midi([60, 64, 67], [0.0, 0.0, 0.0], [0.5] * 3) + result = compare_performance_ED(res, ref) + assert result.stats["total_chords_missing"] == 1 + + def test_extra_chord_counted(self): + ref = make_midi([60, 64, 67], [0.0, 0.0, 0.0], [0.5] * 3) + res = make_midi([60, 64, 67, 65, 69, 72], + [0.0, 0.0, 0.0, 1.0, 1.0, 1.0], + [0.5] * 6) + result = compare_performance_ED(res, ref) + assert result.stats["total_chords_extra"] == 1 + + def test_imperfect_chord_counted(self): + # One pitch wrong in a three-note chord -> imperfect, not fully correct + ref = make_midi([60, 64, 67], [0.0, 0.0, 0.0], [0.5, 0.5, 0.5]) + res = make_midi([60, 63, 67], [0.0, 0.0, 0.0], [0.5, 0.5, 0.5]) + result = compare_performance_ED(res, ref) + assert result.stats["total_chords_imperfect"] == 1 + assert result.stats["total_chords_correct"] == 0 + + def test_stats_with_mix_note_chord(self): + # Reference: one single note then C major and F-major chords. + # Response: missing the F-major chord. Note is correct. + # Checks that chord stats are counted correctly even when notes are present. + ref = make_midi([60, 60, 64, 67, 65, 69, 72], + [0.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0], + [0.4, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]) + res = make_midi([60, 60, 64, 67], + [0.0, 1.0, 1.0, 1.0], + [0.4, 0.5, 0.5, 0.5]) + result = compare_performance_ED(res, ref) + assert result.stats["total_notes_missing"] == 0 + assert result.stats["total_chords_missing"] == 1 + assert result.stats["total_chords_correct"] == 1 + assert result.stats["total_notes_in_reference"] == 1 + assert result.stats["total_chords_in_reference"] == 2 + -# 6. Tests for evaluation_function (Lambda Feedback integration) +# 8. Tests for evaluation_function (Lambda Feedback integration) # ------------------------------------------------------------------------------ -class TestEvaluationFunction: +class TestEvaluationFunction(unittest.TestCase): """ the core logic is already covered by TestComparePerformanceED above. simple checks here to ensure the interface is working as expected. """ def test_perfect_performance_is_correct(self): - midi = make_midi([60, 62, 64, 65], [0, 0.5, 1.0, 1.5], [0.4] * 4) + # Single note (C4) then one C major chord + midi = make_midi([60, 60, 64, 67], [0.0, 1.0, 1.0, 1.0], [0.4, 0.5, 0.5, 0.5]) result = evaluation_function(midi, midi, {}) assert result["is_correct"] == True def test_pitch_error_is_not_correct(self): - ref = make_midi([60, 62], [0, 0.5], [0.4, 0.4]) - res = make_midi([60, 65], [0, 0.5], [0.4, 0.4]) + ref = make_midi([60, 60, 64, 67], [0.0, 1.0, 1.0, 1.0], [0.4, 0.5, 0.5, 0.5]) + res = make_midi([60, 60, 63, 67], [0.0, 1.0, 1.0, 1.0], [0.4, 0.5, 0.5, 0.5]) result = evaluation_function(res, ref, {}) assert result["is_correct"] == False -# 7. Tests for parameter overrides +# 9. Tests for parameter overrides # ------------------------------------------------------------------------------ -class TestParamOverrides: +class TestParamOverrides(unittest.TestCase): def test_tight_timing_threshold_triggers_warning(self): - """ - A very strict timing threshold should flag a note that is slightly late. - Note 3 is late while others are on time, so the residual is detectable. - """ + # A very strict timing threshold should flag a note that is slightly late. + # Note 3 is late while others are on time, so the residual is detectable. ref = make_midi([60, 62, 64, 65], [0, 0.5, 1.0, 1.5], [0.4] * 4) res = make_midi([60, 62, 64, 65], [0, 0.5, 1.3, 1.5], [0.4] * 4) result = compare_performance_ED( @@ -263,4 +540,15 @@ def test_custom_gap_penalty_passed_through(self): result_lenient = compare_performance_ED(res, ref, gap_penalty=1) assert result_lenient.stats["total_notes_missing"] > 0 result_default = compare_performance_ED(res, ref) - assert result_default.stats["total_notes_missing"] == 0 \ No newline at end of file + assert result_default.stats["total_notes_missing"] == 0 + + def test_custom_chord_onset_window_affects_grouping(self): + # Two notes 80ms apart: grouped with 100ms window, separate with 50ms window + ref = make_midi([60, 64], [0.00, 0.08], [0.5, 0.5]) + res = make_midi([60, 64], [0.00, 0.08], [0.5, 0.5]) + # With 100ms window: grouped as one chord -> chord_count=1 + result_wide = compare_performance_ED(res, ref, chord_onset_window=0.10) + # With 50ms window: two separate notes -> chord_count=0 + result_narrow = compare_performance_ED(res, ref, chord_onset_window=0.05) + assert result_wide.stats["total_chords_in_reference"] == 1 + assert result_narrow.stats["total_notes_in_reference"] == 2 \ No newline at end of file From 9974e980af4e7215183e61de631a7ea42083c414 Mon Sep 17 00:00:00 2001 From: ada-3e212e610b Date: Sat, 27 Jun 2026 22:09:39 +0100 Subject: [PATCH 7/8] fix Keyerror, check for typo --- evaluation_function/compare_MIDI.py | 30 ++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/evaluation_function/compare_MIDI.py b/evaluation_function/compare_MIDI.py index bde6b46..98302a2 100644 --- a/evaluation_function/compare_MIDI.py +++ b/evaluation_function/compare_MIDI.py @@ -24,7 +24,7 @@ # Gap penalty: cost of leaving a note unaligned (insertion/deletion) DEFAULT_GAP_PENALTY = 6 -# Timing: |response_start - predicted_start| / IOI must be below this. +# Timing: |response_start - predicted_start| / Inter-Onset Interval (IOI) must be below this. # e.g. 0.20 means the start can be off by up to 20% of the inter-onset interval. TIMING_RELATIVE_THRESHOLD = 0.20 @@ -73,10 +73,7 @@ def identify_chord_name(notes): """ Identify the chord name (e.g. "C major", "A minor") from a list of notes by matching their pitch class set against CHORD_TEMPLATES. - - For each candidate root in the pitch class set, normalise all pitch classes - to start at 0 and check against each template. If no match is found, - returns "unknown chord". + If no match is found, returns "unknown chord". Args: notes: list of note dicts, each with a "pitch" key. @@ -95,20 +92,16 @@ def identify_chord_name(notes): return "unknown chord" - def compute_chord_accuracy(ref_notes, res_notes): """ - Compute the chord accuracy score A from McLeod & Rohit (2022), - + Compute the chord accuracy score A from (Devaney, n.d.): A = (C - I + |y|) / (2 * |y|) - where: C = |y ∩ y_hat| (correctly played pitch classes) - I = |y_hat - y| (extra pitch classes played) + I = |y_hat - y| (unexpected pitch classes played) |y| (number of pitch classes in the reference chord) - - A = 1.0 means perfectly correct. A = 0.0 means nothing correct and - many extra notes are played. + A = 1.0 means perfectly correct. + A = 0.0 means nothing correct and many unexpected notes are played. Args: ref_notes: list of note dicts for the reference chord. @@ -118,7 +111,7 @@ def compute_chord_accuracy(ref_notes, res_notes): accuracy: float in [0, 1] correct_pitches: sorted list of pitch class ints in both chords missing_pitches: sorted list of pitch class ints in ref - extra_pitches:sorted list of pitch class ints in response + extra_pitches: sorted list of pitch class ints in response """ ref_pcs = get_pitch_class_set(ref_notes) res_pcs = get_pitch_class_set(res_notes) @@ -226,7 +219,7 @@ def group_notes_into_events(notes, chord_onset_window=DEFAULT_CHORD_ONSET_WINDOW if len(notes) == 0: return [] - # Sort notes by start time first + # make sure notes are sorted by start time first sorted_notes = sorted(notes, key=lambda n: n["start"]) events = [] @@ -335,6 +328,13 @@ def event_alignment_ED(response_events, ref_events, gap_penalty=DEFAULT_GAP_PENA 'cost': int} D: accumulated cost matrix, shape (N+1, M+1) """ + # if a raw note dict with "pitch"/"start"/"duration" but no "event_type" is + # passed in, group them into events first. + if "event_type" not in response_events[0]: + response_events = group_notes_into_events(response_events) + if "event_type" not in ref_events[0]: + ref_events = group_notes_into_events(ref_events) + # the rows of D correspond to response events N = len(response_events) # the columns of D correspond to reference events From 129b897780652ce23c6bf122270bc942f4bc7545 Mon Sep 17 00:00:00 2001 From: ada-3e212e610b Date: Sat, 27 Jun 2026 22:10:14 +0100 Subject: [PATCH 8/8] add new parameter --- evaluation_function/evaluation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/evaluation_function/evaluation.py b/evaluation_function/evaluation.py index 176fdaa..3b5affe 100755 --- a/evaluation_function/evaluation.py +++ b/evaluation_function/evaluation.py @@ -17,6 +17,7 @@ DURATION_RELATIVE_THRESHOLD, GLOBAL_SLOW_THRESHOLD, GLOBAL_FAST_THRESHOLD, + DEFAULT_CHORD_ONSET_WINDOW ) @@ -66,6 +67,9 @@ def evaluation_function( global_fast_threshold=params.get( "global_fast_threshold", GLOBAL_FAST_THRESHOLD ), + chord_onset_window=params.get( + "chord_onset_window", DEFAULT_CHORD_ONSET_WINDOW + ) ) return {