Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
cmake_minimum_required(VERSION 3.21...4.0)
project(dmri-commit CXX C)
find_package(Python COMPONENTS Interpreter Development.Module NumPy REQUIRED)
include(UseCython)

set(CYTHON_ARGS "-3")

execute_process(
COMMAND "${Python_EXECUTABLE}" -c
"import sys; sys.path.insert(0,'${CMAKE_CURRENT_SOURCE_DIR}'); from setup_operator import write_operator_c_file; write_operator_c_file()"
WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}"
COMMAND_ERROR_IS_FATAL ANY
)

# trx-cpp via FetchContent
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
include(FetchContent)

set(TRX_BUILD_TESTS OFF)
set(TRX_BUILD_EXAMPLES OFF)
set(TRX_BUILD_DOCS OFF)
set(TRX_BUILD_BENCHMARKS OFF)
set(TRX_ENABLE_NIFTI OFF)
set(TRX_ENABLE_INSTALL OFF)

FetchContent_Declare(
trx-cpp
GIT_REPOSITORY https://github.com/tee-ar-ex/trx-cpp.git
GIT_TAG main
GIT_SHALLOW TRUE
)
FetchContent_MakeAvailable(trx-cpp)

# json11.cpp is missing '#include <cstdint>'; force-include it
if(MSVC)
target_compile_options(trx PRIVATE /FIcstdint)
else()
target_compile_options(trx PRIVATE -include cstdint)
endif()

cython_transpile(commit/trk2dictionary/trk2dictionary.pyx LANGUAGE CXX OUTPUT_VARIABLE trk2dict_src)
Python_add_library(trk2dictionary MODULE "${trk2dict_src}" WITH_SOABI)
target_compile_features(trk2dictionary PRIVATE cxx_std_17)
target_compile_options(trk2dictionary PRIVATE $<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-w>)
target_include_directories(trk2dictionary PRIVATE
"${CMAKE_CURRENT_SOURCE_DIR}/commit/trk2dictionary")
target_link_libraries(trk2dictionary PRIVATE trx-cpp::trx Python::NumPy)
install(TARGETS trk2dictionary DESTINATION commit)

cython_transpile(commit/core.pyx LANGUAGE CXX OUTPUT_VARIABLE core_src)
Python_add_library(core MODULE "${core_src}" WITH_SOABI)
target_compile_features(core PRIVATE cxx_std_11)
target_compile_options(core PRIVATE $<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-w>)
target_link_libraries(core PRIVATE Python::NumPy)
install(TARGETS core DESTINATION commit)

cython_transpile(commit/proximals.pyx LANGUAGE CXX OUTPUT_VARIABLE proximals_src)
Python_add_library(proximals MODULE "${proximals_src}" WITH_SOABI)
target_compile_features(proximals PRIVATE cxx_std_11)
target_compile_options(proximals PRIVATE $<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-w>)
target_link_libraries(proximals PRIVATE Python::NumPy)
install(TARGETS proximals DESTINATION commit)

cython_transpile(commit/models.pyx LANGUAGE CXX OUTPUT_VARIABLE models_src)
Python_add_library(models MODULE "${models_src}" WITH_SOABI)
target_compile_features(models PRIVATE cxx_std_11)
target_compile_options(models PRIVATE $<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-w>)
target_link_libraries(models PRIVATE Python::NumPy)
install(TARGETS models DESTINATION commit)

cython_transpile(commit/operator/operator.pyx LANGUAGE C OUTPUT_VARIABLE operator_src)
Python_add_library(commit_operator MODULE
"${operator_src}"
"${CMAKE_CURRENT_SOURCE_DIR}/commit/operator/operator_c.c"
WITH_SOABI
)
set_target_properties(commit_operator PROPERTIES OUTPUT_NAME operator)
target_compile_options(commit_operator PRIVATE
$<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-w -O3 -Ofast>
$<$<CXX_COMPILER_ID:MSVC>:/fp:fast /DHAVE_STRUCT_TIMESPEC>
)
target_link_libraries(commit_operator PRIVATE Python::NumPy)
if(WIN32)
target_include_directories(commit_operator PRIVATE "$ENV{PTHREAD_WIN_INCLUDE}")
target_link_directories(commit_operator PRIVATE "$ENV{PTHREAD_WIN_LIB}")
target_link_libraries(commit_operator PRIVATE pthread)
endif()
install(TARGETS commit_operator DESTINATION commit/operator)
39 changes: 29 additions & 10 deletions commit/trk2dictionary/trk2dictionary.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import pickle
from importlib import metadata
import shutil
import time
from trx.io import load as load_trx

logger = setup_logger('trk2dictionary')

Expand Down Expand Up @@ -340,14 +341,28 @@ cpdef run( filename_tractogram=None, path_out=None, filename_peaks=None, filenam
if not exists(filename_tractogram):
logger.error( f'Tractogram file not found: {filename_tractogram}' )
extension = splitext(filename_tractogram)[1]
if extension != ".trk" and extension != ".tck":
logger.error( 'Invalid input file: only .trk and .tck are supported' )

hdr = nibabel.streamlines.load( filename_tractogram, lazy_load=True ).header


if extension == ".trk":
logger.subinfo ( f'geometry taken from "{filename_tractogram}"', indent_lvl=3, indent_char='-' )
if extension != ".trx" and extension != ".trk" and extension != ".tck":
logger.error( 'Invalid input file: only .trx, .trk, and .tck are supported')

if extension == ".trx":
hdr = load_trx(filename_tractogram, "same").header
affine = hdr["VOXEL_TO_RASMM"]
voxel_sizes = nibabel.affines.voxel_sizes(affine)

Nx = int(hdr['DIMENSIONS'][0])
Ny = int(hdr['DIMENSIONS'][1])
Nz = int(hdr['DIMENSIONS'][2])
Px = voxel_sizes[0]
Py = voxel_sizes[1]
Pz = voxel_sizes[2]

data_offset = 0 # stored separately in .trx
n_count = hdr['NB_STREAMLINES']
n_scalars = 0 # stored separately in .trx
n_properties = 0 # stored separately in .trx
elif extension == ".trk":
hdr = nibabel.streamlines.load( filename_tractogram, lazy_load=True ).header

Nx = int(hdr['dimensions'][0])
Ny = int(hdr['dimensions'][1])
Nz = int(hdr['dimensions'][2])
Expand All @@ -369,6 +384,7 @@ cpdef run( filename_tractogram=None, path_out=None, filename_peaks=None, filenam
else:
logger.error( 'TCK files do not contain information about the geometry. Use "TCK_ref_image" for that' )
logger.subinfo ( f'geometry taken from "{TCK_ref_image}"', indent_lvl=3, indent_char='-' )
hdr = nibabel.streamlines.load( filename_tractogram, lazy_load=True ).header

niiREF = nibabel.load( TCK_ref_image )
niiREF_hdr = _get_header( niiREF )
Expand Down Expand Up @@ -408,8 +424,11 @@ cpdef run( filename_tractogram=None, path_out=None, filename_peaks=None, filenam
# get toVOXMM matrix (remove voxel scaling from affine) in case of TCK
cdef float [:] toVOXMM
cdef float* ptrToVOXMM
if extension == ".tck":
M = _get_affine( niiREF ).copy()
if extension == ".tck" or extension == ".trx":
if extension == ".tck":
M = _get_affine( niiREF ).copy()
else:
M = np.asarray(affine, dtype=np.float64).copy()
# float64 conversion added to comply with the new cast policy of numpy v2
M[:3, :3] = M[:3, :3].dot( np.diag([np.float64(1)/Px,np.float64(1)/Py,np.float64(1)/Pz]) )
toVOXMM = np.ravel(np.linalg.inv(M)).astype('<f4')
Expand Down
136 changes: 90 additions & 46 deletions commit/trk2dictionary/trk2dictionary_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
#include <math.h>
#include <iostream>
#include <thread>
#include <functional>
#include <numeric>
#include <chrono>
#include <trx/trx.h>

#define _FILE_OFFSET_BITS 64
#define MAX_FIB_LEN 10000
Expand Down Expand Up @@ -88,12 +90,13 @@ int verbosity = 0;
bool rayBoxIntersection( Vector<double>& origin, Vector<double>& direction, Vector<double>& vmin, Vector<double>& vmax, double & t);
void fiberForwardModel( float fiber[3][MAX_FIB_LEN], unsigned int pts, int nReplicas, double* ptrBlurRho, double* ptrBlurAngle, double* ptrBlurWeights, bool doApplyBlur, short* ptrHashTable, vector<Vector<double>>& P );
void segmentForwardModel( const Vector<double>& P1, const Vector<double>& P2, int k, double w, short* ptrHashTable);
unsigned int read_fiberTRX(const trx::AnyTrxFile& fp, float fiber[3][MAX_FIB_LEN], int idx, float* ptrToVOXMM);
unsigned int read_fiberTRK( FILE* fp, float fiber[3][MAX_FIB_LEN], int ns, int np );
unsigned int read_fiberTCK( FILE* fp, float fiber[3][MAX_FIB_LEN] , float* toVOXMM );


// ---------- Parallel fuction --------------
int ICSegments( char* str_filename, int isTRK, int n_count, int nReplicas, int n_scalars, int n_properties, float* ptrToVOXMM,
int ICSegments( char* str_filename, const trx::AnyTrxFile& trxFile, int isTRX, int isTRK, int n_count, int nReplicas, int n_scalars, int n_properties, float* ptrToVOXMM,
double* ptrTDI , double* ptrBlurRho, double* ptrBlurAngle, double* ptrBlurWeights, bool* ptrBlurApplyTo, short* ptrHashTable, char* path_out,
unsigned long long int offset, int idx, unsigned int startpos, unsigned int endpos );

Expand Down Expand Up @@ -165,69 +168,81 @@ int trk2dictionary(
// Check the file extension
// -------------------------------------
int isTRK;
int isTRX;

char *ext = strrchr(str_filename, '.');
if (strcmp(ext,".trk")==0) // for .trk file
if (strcmp(ext,".trx")==0) { // for .trx file
isTRK = 0;
isTRX = 1;
} else if (strcmp(ext,".trk")==0) { // for .trk file
isTRK = 1;
else if (strcmp(ext,".tck")==0) // for .tck file
isTRX = 0;
} else if (strcmp(ext,".tck")==0) { // for .tck file
isTRK = 0;
else
return 0;
isTRX = 0;
} else return 0;


// Open tractogram file and compute the offset for each thread
// This is only needed for .trk and .tck files
// For trx, we can use the trx::AnyTrxFile class to read
// the streamlines in parallel
// -----------------------------------------------------------------
trx::AnyTrxFile trxFile;
unsigned long long int current;
unsigned long long int *OffsetArr = new unsigned long long int[threads_count]();
int f = 0;
float *Buff = new float[3]();
int N;

FILE* fpTractogram = fopen(str_filename,"rb");
if (fpTractogram == NULL) return 0;
fseek( fpTractogram, data_offset, SEEK_SET ); // skip the header

OffsetArr[0] = ftell( fpTractogram );

if(isTRK) {
while( f != n_count) {
fread( (char*)&N, 1, 4, fpTractogram ); // read the number of points in each streamline
if( N >= MAX_FIB_LEN || N <= 0 ) return 0; // check
for( int k=0; k<N; k++){
fread((char*)Buff, 1, 12, fpTractogram);
}
fseek(fpTractogram,4*n_properties,SEEK_CUR);
f++;
current = ftell( fpTractogram );
for( int i = 1; i < threads_count; i++ ){
if( f == Pos[i] )
OffsetArr[i] = current;
}
}
if (isTRX) {
trxFile = trx::AnyTrxFile::load(str_filename);
} else {
FILE* fpTractogram = fopen(str_filename,"rb");
if (fpTractogram == NULL) return 0;
fseek( fpTractogram, data_offset, SEEK_SET ); // skip the header

OffsetArr[0] = ftell( fpTractogram );

if(isTRK) {
while( f != n_count) {
fread( (char*)&N, 1, 4, fpTractogram ); // read the number of points in each streamline
if( N >= MAX_FIB_LEN || N <= 0 ) return 0; // check
for( int k=0; k<N; k++){
fread((char*)Buff, 1, 12, fpTractogram);
}
fseek(fpTractogram,4*n_properties,SEEK_CUR);
f++;
current = ftell( fpTractogram );
for( int i = 1; i < threads_count; i++ ){
if( f == Pos[i] )
OffsetArr[i] = current;
}
}
} else {

while( f != n_count ) {
while( f != n_count ) {

fread((char*)Buff, 1, 12, fpTractogram );
fread((char*)Buff, 1, 12, fpTractogram );

if( isnan(Buff[0]) ){
f++;
current = ftell( fpTractogram );
if( isnan(Buff[0]) ){
f++;
current = ftell( fpTractogram );

for( int i = 1; i < threads_count; i++ ){
if( f == Pos[i] )
OffsetArr[i] = current;
}

for( int i = 1; i < threads_count; i++ ){
if( f == Pos[i] )
OffsetArr[i] = current;
}

}

}

fclose(fpTractogram);
}

fclose(fpTractogram);



// ==========================================
// Parallel IC compartments
Expand All @@ -243,7 +258,7 @@ int trk2dictionary(
}
// ---- Original ------
for( int i = 0; i<threads_count; i++ ){
threads.push_back( thread( ICSegments, str_filename, isTRK, n_count, nReplicas, n_scalars, n_properties, ptrToVOXMM,
threads.push_back( thread( ICSegments, str_filename, std::ref(trxFile), isTRX, isTRK, n_count, nReplicas, n_scalars, n_properties, ptrToVOXMM,
ptrTDI[i] , ptrBlurRho, ptrBlurAngle, ptrBlurWeights, ptrBlurApplyTo, ptrHashTable, path_out, OffsetArr[i],
i, Pos[i], Pos[i+1] ) );
}
Expand Down Expand Up @@ -282,6 +297,9 @@ int trk2dictionary(
delete[] OffsetArr;
delete[] Buff;

if (isTRX)
trxFile.close();

return 1;
}

Expand Down Expand Up @@ -443,7 +461,7 @@ int ISOcompartments(double** ptrTDI, char* path_out, int threads)
/* Parallel Function */
/********************************************************************************************************************/

int ICSegments( char* str_filename, int isTRK, int n_count, int nReplicas, int n_scalars, int n_properties, float* ptrToVOXMM, double* ptrTDI, double* ptrBlurRho,
int ICSegments( char* str_filename, const trx::AnyTrxFile& trxFile, int isTRX, int isTRK, int n_count, int nReplicas, int n_scalars, int n_properties, float* ptrToVOXMM, double* ptrTDI, double* ptrBlurRho,
double* ptrBlurAngle, double* ptrBlurWeights, bool* ptrBlurApplyTo, short* ptrHashTable, char* path_out,
unsigned long long int offset, int idx, unsigned int startpos, unsigned int endpos )
{
Expand Down Expand Up @@ -486,9 +504,14 @@ unsigned long long int offset, int idx, unsigned int startpos, unsigned int endp

// ---- Original -----
// Open tractogram file
FILE* fpTractogram1 = fopen( str_filename,"rb" );
if ( fpTractogram1 == NULL ) return 0; // if there's no tractogram file, then return 0
fseek(fpTractogram1, offset, SEEK_SET);
FILE* fpTractogram1;

if ( !isTRX ) {
FILE* fp = fopen( str_filename, "rb" );
if ( fp == NULL ) return 0;
fseek(fp, offset, SEEK_SET);
fpTractogram1 = fp;
}

tempTotFibers = 0;
temp_totICSegments = 0;
Expand All @@ -499,11 +522,12 @@ unsigned long long int offset, int idx, unsigned int startpos, unsigned int endp
for(int f=startpos; f<endpos; f++)
{

if ( isTRK )
if ( isTRX )
N = read_fiberTRX( trxFile, fiber, idx, ptrToVOXMM );
else if ( isTRK )
N = read_fiberTRK( fpTractogram1, fiber, n_scalars, n_properties );
else
N = read_fiberTCK( fpTractogram1, fiber , ptrToVOXMM );

N = read_fiberTCK( fpTractogram1, fiber, ptrToVOXMM );

fiberForwardModel( fiber, N, nReplicas, ptrBlurRho, ptrBlurAngle, ptrBlurWeights, ptrBlurApplyTo[f], ptrHashTable, P );

Expand Down Expand Up @@ -564,7 +588,8 @@ unsigned long long int offset, int idx, unsigned int startpos, unsigned int endp
}
}
}
fclose( fpTractogram1 );
if ( !isTRX )
fclose( fpTractogram1 );
fclose( pDict_TRK_norm );
fclose( pDict_IC_f );
fclose( pDict_IC_v );
Expand Down Expand Up @@ -869,6 +894,25 @@ bool rayBoxIntersection( Vector<double>& origin, Vector<double>& direction, Vect
}


unsigned int read_fiberTRX(const trx::AnyTrxFile& fp, float fiber[3][MAX_FIB_LEN], int idx, float* ptrToVOXMM)
{
std::vector<std::array<double, 3>> streamline = fp.get_streamline(idx);
const unsigned int N = streamline.size();

if ( N >= MAX_FIB_LEN || N <= 0 )
return 0;

for (uint64_t i = 0; i < N; i++) {
const float x = static_cast<float>(streamline[i][0]);
const float y = static_cast<float>(streamline[i][1]);
const float z = static_cast<float>(streamline[i][2]);
fiber[0][i] = x * ptrToVOXMM[0] + y * ptrToVOXMM[1] + z * ptrToVOXMM[2] + ptrToVOXMM[3];
fiber[1][i] = x * ptrToVOXMM[4] + y * ptrToVOXMM[5] + z * ptrToVOXMM[6] + ptrToVOXMM[7];
fiber[2][i] = x * ptrToVOXMM[8] + y * ptrToVOXMM[9] + z * ptrToVOXMM[10] + ptrToVOXMM[11];
}
return N;
}

// Read a fiber from file .trk
unsigned int read_fiberTRK( FILE* fp, float fiber[3][MAX_FIB_LEN], int ns, int np )
{
Expand Down
Loading