feat(dpmodel): add backend-independent trainer abstraction#5603
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Repository UI Review profile: CHILL Plan: Pro Run ID: 📒 Files selected for processing (8)
🚧 Files skipped from review as they are similar to previous changes (7)
📝 WalkthroughWalkthroughAdds shared training abstractions for entrypoints, trainers, task normalization, and validation. Refactors JAX and pt_expt training flows to use the shared pipeline. Consolidates finetune rule handling into shared utilities. Extends JAX checkpoint serialization for multitask models. ChangesShared Training Abstractions and Backend Integrations
Estimated code review effort: 5 (Critical) | ~120 minutes Possibly related PRs
Suggested reviewers: 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 4
🧹 Nitpick comments (1)
source/tests/test_dpmodel_abstract_trainer.py (1)
99-109: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick winAvoid hard-coding the default task key in this test.
TrainingTaskCollection.single()owns the default key viaDEFAULT_TASK_KEY, so indexing with"Default"makes this test fail on an internal rename without any behavior change. Pull the lone task from the collection API instead.Proposed fix
- task = tasks["Default"] + task = tasks.select() task.add_data_requirements() assert not tasks.is_multitask assert tasks.select() is task🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@source/tests/test_dpmodel_abstract_trainer.py` around lines 99 - 109, The test is hard-coding the default task name instead of using the collection API. In test_dpmodel_abstract_trainer.py, update the TrainingTaskCollection.single() usage so the lone task is retrieved without indexing by "Default", and reference the collection’s default-task behavior through its API/select() rather than the literal key. Keep the assertions on task.add_data_requirements(), tasks.is_multitask, and tasks.select() intact, but bind task from the collection in a way that won’t break if DEFAULT_TASK_KEY is renamed.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/dpmodel/train/trainer.py`:
- Around line 380-383: Keep the single-task table schema anchored to
train_results in the Trainer formatting path. The header logic and the row
emission in the same train/eval summary should use the exact same key sequence
from train_results, not valid_results, so the validation columns stay aligned
and missing or reordered validation metrics do not break train_results lookups.
Update the loop in the Trainer row-building code to derive the metric order from
train_results and only read matching validation values by key when present.
- Around line 148-161: The task normalization in the trainer constructor
silently overwrites duplicate entries when `tasks` is a sequence, so add an
explicit duplicate-key check before building `self._tasks` in
`Trainer.__init__`. Update the branch that handles non-Mapping `tasks` to
validate that each `task.key` appears only once, and raise a clear `ValueError`
if duplicates are found. Keep the existing key-matching validation and
`_normalize_probabilities` flow intact.
In `@deepmd/jax/train/trainer.py`:
- Around line 188-194: The training setup in
AbstractTrainer.on_train_begin()/train() is missing registration of loss label
requirements before creating the task collection. Update the training flow
around TrainingTaskCollection.single and self.run(tasks) so
DeepmdDataSystem.add_data_requirements() is called with
self.loss.label_requirement (or equivalent task data_requirements) before
batching starts, ensuring get_batch() includes the labels needed by the loss
path.
In `@source/tests/test_dpmodel_abstract_trainer.py`:
- Around line 53-64: The evaluate_training helper is consuming a batch for
inactive multitask entries by falling back to task.training_data.get_batch()
when step_result is None, which advances the cursor during display collection.
Update evaluate_training in test_dpmodel_abstract_trainer.py so it only reads
step_result.payload for the matching task and otherwise returns a non-consuming
placeholder/skip path for inactive tasks; keep the batch-order contract aligned
with the multitask loop in trainer.py and the TrainStepResult/task.key check.
---
Nitpick comments:
In `@source/tests/test_dpmodel_abstract_trainer.py`:
- Around line 99-109: The test is hard-coding the default task name instead of
using the collection API. In test_dpmodel_abstract_trainer.py, update the
TrainingTaskCollection.single() usage so the lone task is retrieved without
indexing by "Default", and reference the collection’s default-task behavior
through its API/select() rather than the literal key. Keep the assertions on
task.add_data_requirements(), tasks.is_multitask, and tasks.select() intact, but
bind task from the collection in a way that won’t break if DEFAULT_TASK_KEY is
renamed.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: ae088f30-bef5-4da4-ae02-40f16ee87975
📒 Files selected for processing (4)
deepmd/dpmodel/train/__init__.pydeepmd/dpmodel/train/trainer.pydeepmd/jax/train/trainer.pysource/tests/test_dpmodel_abstract_trainer.py
d100673 to
6e168b7
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
♻️ Duplicate comments (2)
deepmd/dpmodel/train/trainer.py (2)
422-428: 🩺 Stability & Availability | 🟠 Major | ⚡ Quick winSingle-task row schema should anchor to
train_results, notvalid_results.
format_headeriteratestrain_results(Line 382), but the row loop here iteratesvalid_resultsand indexestrain_results[key]. If a backend returns validation metrics in a different order or omits a metric, the row desynchronizes from the header andtrain_results[key]can raiseKeyError.Suggested fix
if valid_results is not None: assert not self._is_multitask(valid_results) - for key in valid_results: + for key in train_results: row += ( - f" {float(valid_results[key]):11.2e}" + f" {float(valid_results.get(key, float('nan'))):11.2e}" f" {float(train_results[key]):11.2e}" )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/dpmodel/train/trainer.py` around lines 422 - 428, The single-task row assembly in `Trainer.format_*` is anchored to the wrong metric source: the header follows `train_results`, but this row loop currently iterates `valid_results` while reading `train_results[key]`, which can desynchronize columns or fail when validation metrics differ. Update the row-building logic to use the same key order as `train_results` (matching the existing `format_header` contract) and only use `valid_results` for the optional validation value lookup, keeping the schema consistent with the training metrics.
148-161: 🎯 Functional Correctness | 🟠 Major | ⚡ Quick winDuplicate task keys are still silently dropped when
tasksis a sequence.
task_dict = {task.key: task for task in tasks}overwrites earlier entries, so a misconfigured multitask run can lose a task with no error. Validate uniqueness before building_tasks.Suggested fix
if isinstance(tasks, Mapping): task_dict = dict(tasks) else: - task_dict = {task.key: task for task in tasks} + task_list = list(tasks) + task_dict = {task.key: task for task in task_list} + if len(task_dict) != len(task_list): + raise ValueError("Training task keys must be unique.")🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/dpmodel/train/trainer.py` around lines 148 - 161, The task normalization in Trainer.__init__ still silently overwrites duplicate keys when tasks is a sequence because task_dict is built directly from task.key; add an explicit duplicate-key check before assigning to self._tasks. In the branch that handles non-Mapping tasks, validate that each task.key is unique (raise a ValueError on duplicates) before constructing the dict, while keeping the existing key-vs-task.key consistency check and downstream _normalize_probabilities flow unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/pt_expt/train/training.py`:
- Line 831: Wire the parsed retention setting into checkpoint cleanup so
`max_ckpt_keep` actually limits saved checkpoints: `TrainerConfig` already
carries the value and `AbstractTrainer.run()`/`save_checkpoint()` currently only
append new `self.save_ckpt-<step>.pt` files. Update the checkpoint-saving flow
to track existing checkpoints and prune the oldest ones after each save, using
`self.max_ckpt_keep` as the cap and keeping the newest checkpoints only.
---
Duplicate comments:
In `@deepmd/dpmodel/train/trainer.py`:
- Around line 422-428: The single-task row assembly in `Trainer.format_*` is
anchored to the wrong metric source: the header follows `train_results`, but
this row loop currently iterates `valid_results` while reading
`train_results[key]`, which can desynchronize columns or fail when validation
metrics differ. Update the row-building logic to use the same key order as
`train_results` (matching the existing `format_header` contract) and only use
`valid_results` for the optional validation value lookup, keeping the schema
consistent with the training metrics.
- Around line 148-161: The task normalization in Trainer.__init__ still silently
overwrites duplicate keys when tasks is a sequence because task_dict is built
directly from task.key; add an explicit duplicate-key check before assigning to
self._tasks. In the branch that handles non-Mapping tasks, validate that each
task.key is unique (raise a ValueError on duplicates) before constructing the
dict, while keeping the existing key-vs-task.key consistency check and
downstream _normalize_probabilities flow unchanged.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 2ee2703c-74d9-48e0-8c5e-ce031e73c19a
📒 Files selected for processing (5)
deepmd/dpmodel/train/__init__.pydeepmd/dpmodel/train/trainer.pydeepmd/jax/train/trainer.pydeepmd/pt_expt/train/training.pysource/tests/test_dpmodel_abstract_trainer.py
✅ Files skipped from review due to trivial changes (1)
- deepmd/dpmodel/train/init.py
🚧 Files skipped from review as they are similar to previous changes (2)
- source/tests/test_dpmodel_abstract_trainer.py
- deepmd/jax/train/trainer.py
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## master #5603 +/- ##
==========================================
- Coverage 81.98% 81.77% -0.22%
==========================================
Files 959 973 +14
Lines 105430 107386 +1956
Branches 4071 4144 +73
==========================================
+ Hits 86442 87811 +1369
- Misses 17518 18057 +539
- Partials 1470 1518 +48 ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
6e168b7 to
78272a7
Compare
There was a problem hiding this comment.
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/dpmodel/train/entrypoint.py`:
- Around line 79-83: The cleanup guard in the entrypoint flow is starting too
late, so `teardown_run()` will not run if `setup_run()` fails after partially
initializing backend state. Move the `try/finally` in the main training
entrypoint so it wraps `setup_run()` itself, ensuring `teardown_run()` is always
invoked even when setup raises. Use the existing `setup_run()` and
`teardown_run()` flow in this module as the anchor when making the change.
- Line 48: The AbstractTrainEntrypoint contract is too implicit: it currently
has no abstract methods and the hook methods are docstring-only, which triggers
lint issues. Update AbstractTrainEntrypoint so run_training() is marked
abstract, and give the optional hook methods explicit no-op bodies with return
None unless they must be overridden by every backend. Use the class and method
names AbstractTrainEntrypoint and run_training() to locate the contract and hook
definitions.
In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 295-297: Replace the existing assert-based validation in the
multi-task config handling with an explicit runtime check that raises ValueError
when the config is "RANDOM"; this ensures the validation in the main entrypoint
is enforced even under optimized Python runs. Update the check in the entrypoint
logic that handles the multi-task configuration so the rejection of "RANDOM" is
always consistent.
- Around line 257-259: `setup_run` and `teardown_run` need ownership tracking
for the distributed process group: make `setup_run` skip `init_process_group()`
when a default group already exists, and record whether this entrypoint created
the group. Then update `teardown_run` to call `destroy_process_group()` only
when it is cleaning up a group created by `setup_run`, so it does not interfere
with a caller-managed distributed context.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: ec38d095-ecf4-4024-ac9f-7523f8d40e60
📒 Files selected for processing (11)
deepmd/dpmodel/train/__init__.pydeepmd/dpmodel/train/entrypoint.pydeepmd/dpmodel/train/trainer.pydeepmd/jax/entrypoints/train.pydeepmd/jax/train/trainer.pydeepmd/pt_expt/entrypoints/main.pydeepmd/pt_expt/train/training.pysource/tests/jax/test_training.pysource/tests/pt_expt/test_entrypoint.pysource/tests/test_dpmodel_abstract_trainer.pysource/tests/test_dpmodel_train_entrypoint.py
🚧 Files skipped from review as they are similar to previous changes (4)
- deepmd/dpmodel/train/init.py
- source/tests/test_dpmodel_abstract_trainer.py
- deepmd/jax/train/trainer.py
- deepmd/pt_expt/train/training.py
78272a7 to
5bebfdf
Compare
There was a problem hiding this comment.
Actionable comments posted: 7
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
deepmd/jax/utils/serialization.py (1)
189-207: 🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick winPreserve string task keys when converting restored state keys.
convert_str_to_int_key(state)also rewrites digit-only task names understate["models"]. A valid multi-task key like"0"becomes0, thenstate_by_model[model_key]fails becausemodel_def_script["model_dict"]still uses"0".Proposed fix
- convert_str_to_int_key(state) - model_def_script = data.model_def_script if "model_dict" in model_def_script: state_by_model = state.get("models", state) + if "models" in state: + for model_state in state_by_model.values(): + convert_str_to_int_key(model_state) + else: + convert_str_to_int_key(state_by_model) model_dict = {"model_dict": {}} for model_key, model_params in model_def_script["model_dict"].items(): abstract_model = get_model(model_params) graphdef, abstract_state = nnx.split(abstract_model) abstract_state.replace_by_pure_dict(state_by_model[model_key]) model = nnx.merge(graphdef, abstract_state) model_dict["model_dict"][model_key] = model.serialize() else: + convert_str_to_int_key(state) abstract_model = get_model(model_def_script)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/jax/utils/serialization.py` around lines 189 - 207, The state restoration logic in convert_str_to_int_key is too aggressive and converts digit-only task names inside state["models"], which breaks later lookup in the model merge path. Update the key conversion so it only normalizes keys that represent numeric indices where needed, while preserving string task keys used by model_def_script["model_dict"]; then ensure the state_by_model lookup in the model merging section continues to use the original task-name strings.deepmd/pt_expt/train/training.py (1)
1509-1515: 🗄️ Data Integrity & Integration | 🟠 Major | ⚡ Quick winMake the
latestsymlink target relative to its own directory.When
save_ckptincludes a directory, e.g.out/model.ckpt,latest.symlink_to("out/model.ckpt-1.pt")createsout/model.ckpt.pt -> out/model.ckpt-1.pt, which resolves asout/out/model.ckpt-1.pt. Restarting from the prefix then follows a broken symlink.Proposed fix
- ckpt_path = f"{self.save_ckpt}-{step}.pt" + ckpt_path = Path(f"{self.save_ckpt}-{step}.pt") torch.save(state, ckpt_path) # symlink latest latest = Path(f"{self.save_ckpt}.pt") if latest.is_symlink() or latest.exists(): latest.unlink() - latest.symlink_to(ckpt_path) + latest.symlink_to(ckpt_path.name)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@deepmd/pt_expt/train/training.py` around lines 1509 - 1515, The `latest` symlink in `training.py` is created with a target that can become incorrectly resolved when `save_ckpt` includes a directory. Update the checkpoint-saving logic in the block that builds `ckpt_path`, `latest = Path(f"{self.save_ckpt}.pt")`, and calls `latest.symlink_to(...)` so the symlink target is computed relative to `latest`’s parent directory instead of using the full path string. Make sure the existing `torch.save` flow and cleanup of any prior `latest` link remain unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/dpmodel/train/data.py`:
- Around line 120-124: The fallback in _print_summary is too broad because it
catches every TypeError from data.print_summary and retries without prob, which
can hide real failures. Update _print_summary to detect the supported call
signature first (or only retry on an explicit argument-count mismatch) so only
old print_summary implementations use the no-prob path while genuine TypeError
exceptions still propagate.
In `@deepmd/dpmodel/train/trainer.py`:
- Around line 575-579: The optional ABC hook methods on Training hooks are
docstring-only and trigger Ruff B027; update the on_train_begin and on_train_end
methods in TrainingTask/Trainer-related code to use explicit no-op bodies by
adding return None so they remain optional without failing lint.
- Around line 472-479: The cleanup guard starts too late in the training setup
path, so resources created by on_train_begin() may not be released if
_open_learning_curve() fails. Move the try/finally in Trainer.train() (or the
surrounding training entry point) so it begins before on_train_begin(tasks), and
keep on_train_end(tasks) in the finally block to ensure backend state is always
cleaned up even when setup throws.
In `@deepmd/dpmodel/utils/training_utils.py`:
- Around line 124-130: The `_training_data_size` helper is swallowing all
`TypeError` from `len(training_data)` and turning broken `__len__`
implementations into a fallback size of 1. Update `_training_data_size` so it
only returns 1 for objects that truly do not support sizing, and allow real
`__len__` failures to propagate instead of masking them; keep the existing
`get_nsystems` path and the `resolve_model_prob()` callers in mind when
adjusting the behavior.
In `@deepmd/jax/utils/finetune.py`:
- Around line 50-53: The single-task pretrained checkpoint path in finetune.py
silently ignores non-empty model_branch_from values other than RANDOM, which can
hide typos and fall back to Default unexpectedly. Update the branch-selection
logic in the finetuning flow around single_config_chosen/model_branch_from to
explicitly validate the branch name when from_multitask is false, and raise an
error for any unknown non-empty value instead of proceeding. Keep RANDOM as the
only special case that enables new_fitting, and ensure the check happens before
the code falls back to Default.
- Around line 152-157: The pre-check in finetune setup is rejecting valid
finetune_head aliases because it compares against raw pretrained_keys before
alias resolution. Update the validation around pretrained_key in the finetune
flow so it uses the same alias-aware mapping as _get_finetune_rule_single() via
get_model_dict(), and only raise the ValueError after checking the resolved
model dict entries rather than the unexpanded pretrained_keys list.
In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 167-177: The _ensure_stat_file_path helper currently creates HDF5
stat files with h5py.File before ensuring the parent directory exists, so paths
like stats/model_stat.h5 can fail during trainer setup. Update
_ensure_stat_file_path to create the parent directories first for file targets,
then open the HDF5 file, and keep the existing directory creation path unchanged
for non-HDF5 stat targets. Use the stat_file_path, Path, and h5py.File logic in
this function to apply the fix.
---
Outside diff comments:
In `@deepmd/jax/utils/serialization.py`:
- Around line 189-207: The state restoration logic in convert_str_to_int_key is
too aggressive and converts digit-only task names inside state["models"], which
breaks later lookup in the model merge path. Update the key conversion so it
only normalizes keys that represent numeric indices where needed, while
preserving string task keys used by model_def_script["model_dict"]; then ensure
the state_by_model lookup in the model merging section continues to use the
original task-name strings.
In `@deepmd/pt_expt/train/training.py`:
- Around line 1509-1515: The `latest` symlink in `training.py` is created with a
target that can become incorrectly resolved when `save_ckpt` includes a
directory. Update the checkpoint-saving logic in the block that builds
`ckpt_path`, `latest = Path(f"{self.save_ckpt}.pt")`, and calls
`latest.symlink_to(...)` so the symlink target is computed relative to
`latest`’s parent directory instead of using the full path string. Make sure the
existing `torch.save` flow and cleanup of any prior `latest` link remain
unchanged.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 78180874-4d87-4c66-bd4e-fd0b4ee961bb
📒 Files selected for processing (15)
deepmd/dpmodel/train/__init__.pydeepmd/dpmodel/train/data.pydeepmd/dpmodel/train/entrypoint.pydeepmd/dpmodel/train/trainer.pydeepmd/dpmodel/utils/training_utils.pydeepmd/jax/entrypoints/train.pydeepmd/jax/train/trainer.pydeepmd/jax/utils/finetune.pydeepmd/jax/utils/serialization.pydeepmd/pt_expt/entrypoints/main.pydeepmd/pt_expt/train/training.pysource/tests/jax/test_training.pysource/tests/pt_expt/test_entrypoint.pysource/tests/test_dpmodel_abstract_trainer.pysource/tests/test_dpmodel_train_entrypoint.py
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/dpmodel/train/init.py
5bebfdf to
9485193
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/jax/utils/finetune.py`:
- Around line 21-25: The suffix check in _load_model_params is too restrictive
and rejects checkpoint directories or pointer paths that serialize_from_file can
already deserialize. Remove the hard-coded .jax validation in _load_model_params
and let serialize_from_file(finetune_model) handle the input format consistently
with the freeze() and init_model paths, while still extracting
"model_def_script" from the loaded checkpoint data.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: cd8a6c26-b867-4429-892e-f4aa7e5fbc5e
📒 Files selected for processing (22)
deepmd/dpmodel/train/__init__.pydeepmd/dpmodel/train/data.pydeepmd/dpmodel/train/entrypoint.pydeepmd/dpmodel/train/trainer.pydeepmd/dpmodel/utils/training_utils.pydeepmd/jax/entrypoints/train.pydeepmd/jax/train/trainer.pydeepmd/jax/utils/finetune.pydeepmd/jax/utils/serialization.pydeepmd/pd/utils/finetune.pydeepmd/pt/utils/finetune.pydeepmd/pt_expt/entrypoints/main.pydeepmd/pt_expt/train/training.pydeepmd/pt_expt/utils/finetune.pydeepmd/utils/finetune.pysource/tests/common/dpmodel/test_train_abstract_trainer.pysource/tests/common/dpmodel/test_train_data.pysource/tests/common/dpmodel/test_train_entrypoint.pysource/tests/common/dpmodel/test_training_utils.pysource/tests/common/test_finetune_utils.pysource/tests/jax/test_training.pysource/tests/pt_expt/test_entrypoint.py
🚧 Files skipped from review as they are similar to previous changes (8)
- deepmd/dpmodel/utils/training_utils.py
- deepmd/dpmodel/train/init.py
- deepmd/jax/utils/serialization.py
- deepmd/dpmodel/train/data.py
- deepmd/pt_expt/train/training.py
- deepmd/jax/entrypoints/train.py
- deepmd/jax/train/trainer.py
- deepmd/pt_expt/entrypoints/main.py
5767485 to
7280c14
Compare
|
Pushed Changes:
Validation:
Both live GPU CLI runs produced |
|
@coderabbitai review |
✅ Action performedReview finished.
|
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@deepmd/dpmodel/train/trainer.py`:
- Around line 230-236: The mapping branch in trainer.py under the probability
handling logic currently validates missing keys but lets extra task probability
keys slip through. Update the probabilities check in the Mapping path to reject
unknown keys as well as missing ones, using self._keys as the only allowed set
before building the numpy array. Keep the fix localized to the probability
normalization/validation block so stale or mistyped entries raise a ValueError
instead of being silently ignored.
In `@deepmd/dpmodel/train/validation.py`:
- Around line 282-326: The validation flow in run() only raises errors on rank
0, so peer JAX processes can keep running after a rank-0 failure. Add a
backend-specific error propagation hook in the validator path around
_raise_if_error and/or the rank-0 try blocks, then implement the JAX version
used by the trainer to broadcast the failure or terminate all processes
consistently. Make sure the hook is invoked for failures in _evaluate,
save_checkpoint/_reconcile_best_checkpoints, and _log_result so non-zero ranks
do not continue silently.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 9857057b-3e07-4295-8a4a-6f044d0b4b1e
📒 Files selected for processing (28)
deepmd/dpmodel/train/__init__.pydeepmd/dpmodel/train/data.pydeepmd/dpmodel/train/entrypoint.pydeepmd/dpmodel/train/trainer.pydeepmd/dpmodel/train/validation.pydeepmd/dpmodel/utils/training_utils.pydeepmd/jax/entrypoints/train.pydeepmd/jax/train/trainer.pydeepmd/jax/train/validation.pydeepmd/jax/utils/finetune.pydeepmd/jax/utils/serialization.pydeepmd/pd/utils/finetune.pydeepmd/pt/train/validation.pydeepmd/pt/utils/finetune.pydeepmd/pt_expt/entrypoints/main.pydeepmd/pt_expt/train/training.pydeepmd/pt_expt/utils/finetune.pydeepmd/utils/argcheck.pydeepmd/utils/finetune.pysource/tests/common/dpmodel/test_train_abstract_trainer.pysource/tests/common/dpmodel/test_train_data.pysource/tests/common/dpmodel/test_train_entrypoint.pysource/tests/common/dpmodel/test_training_utils.pysource/tests/common/test_finetune_utils.pysource/tests/jax/test_training.pysource/tests/pt/test_validation.pysource/tests/pt_expt/test_entrypoint.pysource/tests/pt_expt/test_training.py
💤 Files with no reviewable changes (9)
- source/tests/common/dpmodel/test_train_data.py
- source/tests/common/dpmodel/test_training_utils.py
- source/tests/pt/test_validation.py
- source/tests/pt_expt/test_training.py
- source/tests/common/dpmodel/test_train_entrypoint.py
- source/tests/common/dpmodel/test_train_abstract_trainer.py
- source/tests/pt_expt/test_entrypoint.py
- source/tests/common/test_finetune_utils.py
- source/tests/jax/test_training.py
✅ Files skipped from review due to trivial changes (1)
- deepmd/utils/argcheck.py
🚧 Files skipped from review as they are similar to previous changes (12)
- deepmd/dpmodel/utils/training_utils.py
- deepmd/pt/utils/finetune.py
- deepmd/dpmodel/train/init.py
- deepmd/pd/utils/finetune.py
- deepmd/pt_expt/utils/finetune.py
- deepmd/dpmodel/train/data.py
- deepmd/jax/utils/serialization.py
- deepmd/jax/entrypoints/train.py
- deepmd/utils/finetune.py
- deepmd/jax/train/trainer.py
- deepmd/pt_expt/train/training.py
- deepmd/pt_expt/entrypoints/main.py
There was a problem hiding this comment.
Pull request overview
This PR introduces a backend-independent training abstraction layer under deepmd.dpmodel.train, and migrates JAX and pt_expt training/finetune flows to reuse shared orchestration (tasks/ranks normalization, learning-curve output, checkpoint cadence, lifecycle hooks, and full-validation best-checkpoint management). It also consolidates fine-tuning rule generation into a single backend-agnostic builder in deepmd.utils.finetune, leaving backend modules to focus on checkpoint loading and state/tensor copying.
Changes:
- Add backend-independent trainer/entrypoint/data/full-validation primitives in
deepmd.dpmodel.train, with associated test coverage moved/added undersource/tests/common/dpmodel/. - Centralize fine-tuning rule building in
deepmd.utils.finetuneand refactor PT/PT-exportable/Paddle/JAX finetune helpers to delegate to the shared builder. - Extend full-validation best-checkpoint management to support both file checkpoints (
.pt) and directory checkpoints (.jax), and migrate JAX +pt_expttraining entrypoints onto the shared pipeline.
Reviewed changes
Copilot reviewed 28 out of 28 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| source/tests/pt/test_validation.py | Adds coverage for best-checkpoint reconciliation when checkpoints are directories (e.g., .jax). |
| source/tests/pt_expt/test_training.py | Adds an integration-style test asserting pt_expt full validation writes/prunes best checkpoints and logs. |
| source/tests/pt_expt/test_entrypoint.py | New tests for pt_expt entrypoint option normalization, process-group ownership, checkpoint links/retention, and stat-file creation. |
| source/tests/jax/test_training.py | Adds tests for JAX optimizer LR scaling, finetune loading, state/key normalization, and entrypoint gating/multitask behavior. |
| source/tests/common/test_finetune_utils.py | Adds extensive tests for the shared finetune rule builder (single/multitask, aliases, random fitting, immutability). |
| source/tests/common/dpmodel/test_training_utils.py | New tests for model probability resolution and size fallbacks. |
| source/tests/common/dpmodel/test_train_entrypoint.py | New tests validating the shared entrypoint pipeline sequencing and teardown behavior on failures. |
| source/tests/common/dpmodel/test_train_data.py | New tests for shared data-summary printing compatibility/failure propagation. |
| source/tests/common/dpmodel/test_train_abstract_trainer.py | New tests for the backend-independent trainer loop, multitask sampling, lcurve formatting, checkpoint cadence, and full-validation ordering. |
| deepmd/utils/finetune.py | Introduces FinetuneRuleBuilder and shared finetune-rule construction APIs. |
| deepmd/utils/argcheck.py | Updates full-validation docs to indicate support across PT, pt_expt, and JAX. |
| deepmd/pt/utils/finetune.py | Refactors PyTorch finetune rules to delegate to shared finetune rule builder. |
| deepmd/pt/train/validation.py | Extends best-checkpoint handling to support configurable suffixes and directory checkpoints; improves validation-data iteration. |
| deepmd/pt_expt/utils/finetune.py | Refactors pt_expt finetune rules to delegate to shared finetune rule builder with backend-specific errors. |
| deepmd/pt_expt/train/training.py | Migrates pt_expt training loop to AbstractTrainer, adds task normalization, full validation hook, checkpoint retention, and relative latest-link handling. |
| deepmd/pt_expt/entrypoints/main.py | Migrates pt_expt train entrypoint to the shared entrypoint pipeline and shared data helpers; adds stat-file creation helper. |
| deepmd/pd/utils/finetune.py | Refactors Paddle finetune rules to delegate to shared finetune rule builder. |
| deepmd/jax/utils/serialization.py | Updates JAX serialization to support multitask state layout and avoid mis-normalizing numeric-looking task keys. |
| deepmd/jax/utils/finetune.py | New JAX finetune helper delegating to shared finetune rule builder. |
| deepmd/jax/train/validation.py | New JAX full-validation implementation built on backend-independent FullValidatorBase. |
| deepmd/jax/train/trainer.py | Migrates JAX trainer to AbstractTrainer, adds multitask support, shared finetune integration, full validation hook, and checkpoint writing changes. |
| deepmd/jax/entrypoints/train.py | Migrates JAX train entrypoint to the shared entrypoint pipeline and shared data helpers; adds multitask-aware neighbor-stat update. |
| deepmd/dpmodel/utils/training_utils.py | Improves model-probability resolution by handling get_nsystems() and non-sized data sources. |
| deepmd/dpmodel/train/validation.py | New backend-independent full-validation base with best-checkpoint management and val.log formatting. |
| deepmd/dpmodel/train/trainer.py | New backend-independent trainer abstraction handling task selection, display scheduling, lcurve writing, and checkpoint cadence. |
| deepmd/dpmodel/train/entrypoint.py | New backend-independent entrypoint orchestration pipeline used by backend-specific entrypoints. |
| deepmd/dpmodel/train/data.py | New shared data/task config normalization utilities and summary-printing compatibility helper. |
| deepmd/dpmodel/train/init.py | Exposes the new backend-independent training abstractions as a package API. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
@coderabbitai review |
✅ Action performedReview finished.
|
Summary
deepmd.dpmodel.trainfor task/rank normalization, display scheduling, learning-curve output, checkpoint cadence, lifecycle hooks, and shared train entrypoint orchestration.deepmd.utils.finetune, and reduce the PT, PT-exportable, Paddle, and JAX backend finetune modules to backend-specific checkpoint loading plus shared rule generation.pt_expttrain entrypoint/trainer behavior further onto the shared pipeline, including single-task-as-multi-task normalization, data summaries, checkpoint retention, stat-file parent creation, relative latest checkpoint symlinks, and checkpoint parent creation.print_summaryfallback behavior, broken__len__handling, JAX finetune branch/alias validation, numeric-looking JAX task keys, HDF5 stat paths, andpt_exptcheckpoint symlinks.source/tests/test_dpmodel_*.pyintosource/tests/common/dpmodel/.Refs #5229, #5230, #5231
Tests
ruff format .ruff check .git diff --checkPYTHONPATH=/home/jzzeng/codes/deepmd-kit pytest source/tests/common/dpmodel/test_train_abstract_trainer.py source/tests/common/dpmodel/test_train_entrypoint.py source/tests/common/dpmodel/test_train_data.py source/tests/common/dpmodel/test_training_utils.py source/tests/common/test_finetune_utils.py source/tests/jax/test_training.py source/tests/pt_expt/test_entrypoint.py source/tests/pt_expt/test_multitask.py::TestMultiTaskSeA::test_multitask_finetune source/tests/pt_expt/test_multitask.py::TestMultiTaskSeA::test_multitask_finetune_from_single_task source/tests/pt_expt/test_multitask.py::TestMultiTaskSeA::test_multitask_finetune_no_change_model_params -q(53 passed, 2 subtests passed)PYTHONPATH=/home/jzzeng/codes/deepmd-kit timeout 180 srun --gres=gpu:1 dp --jax train input.json --skip-neighbor-stat --finetune pretrain.jax --use-pretrain-scripton a temporary 1-step water finetune smoke; completed on NVIDIA GeForce RTX 5090 and savedft-model-1.jax.PYTHONPATH=/home/jzzeng/codes/deepmd-kit timeout 180 srun --gres=gpu:1 dp --pt-expt train input.json --skip-neighbor-staton a temporary 2-step water smoke; completed on NVIDIA GeForce RTX 5090, savedckpts/pt-model-2.pt, createdstats/stat.hdf5, and verifiedckpts/pt-model.pt -> pt-model-2.ptwith old step checkpoint pruned bymax_ckpt_keep=1.Notes
paddleis not installed in this environment.deepmd_gnn/CUDA initialization, not by the shared finetune rule builder changes.Summary by CodeRabbit
New Features
val.logreporting (including backend-specific checkpoint suffixes).Bug Fixes
Tests
Documentation